wsntxxn committed on
Commit b4bbb92 · 0 Parent(s):

Init commit

.gitattributes ADDED
@@ -0,0 +1,37 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/egs/*.wav filter=lfs diff=lfs merge=lfs -text
37
+ data/egs/*.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,210 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # PyPI configuration file
171
+ .pypirc
172
+
173
+ # Experiment logs and checkpoints
174
+ experiments/
175
+ checkpoints/
176
+ ckpts/
177
+ evaluation/result/
178
+
179
+ # Raw / processed data file
180
+
181
+ # Docker build files
182
+ docker/
183
+
184
+ # VC logs
185
+ logs/
186
+ *log
187
+
188
+ # VS Code settings
189
+ .vscode/
190
+ pyrightconfig.json
191
+
192
+ # Demo files
193
+ demo/
194
+
195
+ # Binary data files
196
+ **/*.h5
197
+
198
+ test*
199
+
200
+ start.sh
201
+
202
+ # macOS sys files
203
+ .DS_Store
204
+ .AppleDouble
205
+ .LSOverride
206
+ ._*
207
+ .Spotlight-V100
208
+ .Trashes
209
+ .DocumentRevisions-V100
210
+ .fseventsd
README.md ADDED
@@ -0,0 +1,13 @@
1
+ ---
2
+ title: UniFlow Audio
3
+ emoji: 👁
4
+ colorFrom: pink
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 5.49.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,655 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import gradio as gr
5
+
6
+ import spaces
7
+
8
+ from inference_cli import InferenceCLI
9
+
10
+ # Initialize inference CLI
11
+ cli = InferenceCLI()
12
+
13
+ # Available model choices
14
+ MODEL_CHOICES = [
15
+ "UniFlow-Audio-large", "UniFlow-Audio-medium", "UniFlow-Audio-small"
16
+ ]
17
+
18
+ # Default model name
19
+ DEFAULT_MODEL = "UniFlow-Audio-large"
20
+
21
+ # Pre-initialize models
22
+ print("Initializing models, please wait...")
23
+ print(f"Loading main model: {DEFAULT_MODEL}")
24
+ cli.init_model(DEFAULT_MODEL)
25
+
26
+ print("Loading speaker model for TTS...")
27
+ cli.init_speaker_model()
28
+
29
+ print("Loading G2P model for TTS...")
30
+ from montreal_forced_aligner.g2p.generator import PyniniConsoleGenerator
31
+ if not cli.g2p:
32
+ cli.g2p = PyniniConsoleGenerator(
33
+ g2p_model_path=cli.model.g2p_model_path,
34
+ strict_graphemes=False,
35
+ num_pronunciations=1,
36
+ include_bracketed=False
37
+ )
38
+ cli.g2p.setup()
39
+
40
+ print("Loading SVS processor for singing voice synthesis...")
41
+ cli.init_svs_processor()
42
+
43
+ print("Loading video preprocessor for V2A...")
44
+ cli.init_video_preprocessor()
45
+
46
+ print("All models loaded successfully!")
47
+
48
+
49
+ @spaces.GPU(duration=60)
50
+ def text_to_audio(
51
+ caption,
52
+ model_name,
53
+ guidance_scale,
54
+ num_steps,
55
+ progress=gr.Progress(track_tqdm=True)
56
+ ):
57
+ """Text to Audio generation"""
58
+ output_path = "./outputs/t2a_output.wav"
59
+ os.makedirs("./outputs", exist_ok=True)
60
+
61
+ try:
62
+ cli.t2a(
63
+ caption=caption,
64
+ model_name=model_name,
65
+ guidance_scale=guidance_scale,
66
+ num_steps=num_steps,
67
+ output_path=output_path
68
+ )
69
+ return output_path, "Generation successful!"
70
+ except Exception as e:
71
+ return None, f"Error: {str(e)}"
72
+
73
+
74
+ @spaces.GPU(duration=60)
75
+ def text_to_music(
76
+ caption,
77
+ model_name,
78
+ guidance_scale,
79
+ num_steps,
80
+ progress=gr.Progress(track_tqdm=True)
81
+ ):
82
+ """Text to Music generation"""
83
+ output_path = "./outputs/t2m_output.wav"
84
+ os.makedirs("./outputs", exist_ok=True)
85
+
86
+ try:
87
+ cli.t2m(
88
+ caption=caption,
89
+ model_name=model_name,
90
+ guidance_scale=guidance_scale,
91
+ num_steps=num_steps,
92
+ output_path=output_path
93
+ )
94
+ return output_path, "Generation successful!"
95
+ except Exception as e:
96
+ return None, f"Error: {str(e)}"
97
+
98
+
99
+ @spaces.GPU(duration=60)
100
+ def text_to_speech(
101
+ transcript,
102
+ ref_speaker_audio,
103
+ model_name,
104
+ guidance_scale,
105
+ num_steps,
106
+ progress=gr.Progress(track_tqdm=True)
107
+ ):
108
+ """Text to Speech synthesis"""
109
+ output_path = "./outputs/tts_output.wav"
110
+ os.makedirs("./outputs", exist_ok=True)
111
+
112
+ try:
113
+ cli.tts(
114
+ transcript=transcript,
115
+ ref_speaker_speech=ref_speaker_audio,
116
+ model_name=model_name,
117
+ guidance_scale=guidance_scale,
118
+ num_steps=num_steps,
119
+ output_path=output_path
120
+ )
121
+ return output_path, "Generation successful!"
122
+ except Exception as e:
123
+ return None, f"Error: {str(e)}"
124
+
125
+
126
+ @spaces.GPU(duration=60)
127
+ def singing_voice_synthesis(
128
+ singer,
129
+ lyric,
130
+ notes,
131
+ note_durations,
132
+ model_name,
133
+ guidance_scale,
134
+ num_steps,
135
+ progress=gr.Progress(track_tqdm=True)
136
+ ):
137
+ """Singing Voice Synthesis"""
138
+ output_path = "./outputs/svs_output.wav"
139
+ os.makedirs("./outputs", exist_ok=True)
140
+
141
+ try:
142
+ music_score = f"{lyric}<sep>{notes}<sep>{note_durations}"
143
+ cli.svs(
144
+ singer=singer,
145
+ music_score=music_score,
146
+ model_name=model_name,
147
+ guidance_scale=guidance_scale,
148
+ num_steps=num_steps,
149
+ output_path=output_path
150
+ )
151
+ return output_path, "Generation successful!"
152
+ except Exception as e:
153
+ return None, f"Error: {str(e)}"
154
+
155
+
156
+ @spaces.GPU(duration=60)
157
+ def speech_enhancement(
158
+ noisy_audio,
159
+ model_name,
160
+ guidance_scale,
161
+ num_steps,
162
+ progress=gr.Progress(track_tqdm=True)
163
+ ):
164
+ """Speech Enhancement"""
165
+ output_path = "./outputs/se_output.wav"
166
+ os.makedirs("./outputs", exist_ok=True)
167
+
168
+ try:
169
+ cli.se(
170
+ noisy_speech=noisy_audio,
171
+ model_name=model_name,
172
+ guidance_scale=guidance_scale,
173
+ num_steps=num_steps,
174
+ output_path=output_path
175
+ )
176
+ return output_path, "Enhancement successful!"
177
+ except Exception as e:
178
+ return None, f"Error: {str(e)}"
179
+
180
+
181
+ @spaces.GPU(duration=60)
182
+ def audio_super_resolution(
183
+ low_sr_audio,
184
+ model_name,
185
+ guidance_scale,
186
+ num_steps,
187
+ progress=gr.Progress(track_tqdm=True)
188
+ ):
189
+ """Audio Super Resolution"""
190
+ output_path = "./outputs/sr_output.wav"
191
+ os.makedirs("./outputs", exist_ok=True)
192
+
193
+ try:
194
+ cli.sr(
195
+ low_sr_audio=low_sr_audio,
196
+ model_name=model_name,
197
+ guidance_scale=guidance_scale,
198
+ num_steps=num_steps,
199
+ output_path=output_path
200
+ )
201
+ return output_path, "Super-resolution successful!"
202
+ except Exception as e:
203
+ return None, f"Error: {str(e)}"
204
+
205
+
206
+ @spaces.GPU(duration=60)
207
+ def video_to_audio(
208
+ video,
209
+ model_name,
210
+ guidance_scale,
211
+ num_steps,
212
+ progress=gr.Progress(track_tqdm=True)
213
+ ):
214
+ """Video to Audio generation"""
215
+ output_path = "./outputs/v2a_output.mp4"
216
+ os.makedirs("./outputs", exist_ok=True)
217
+
218
+ try:
219
+ cli.v2a(
220
+ video=video,
221
+ model_name=model_name,
222
+ guidance_scale=guidance_scale,
223
+ num_steps=num_steps,
224
+ output_path=output_path
225
+ )
226
+ return output_path, "Generation successful!"
227
+ except Exception as e:
228
+ return None, f"Error: {str(e)}"
229
+
230
+
231
+ # Create Gradio Interface
232
+ with gr.Blocks(
233
+ title="UniFlow-Audio Inference Demo", theme=gr.themes.Soft()
234
+ ) as demo:
235
+ gr.Markdown("# 🔊 UniFlow-Audio Inference Demo")
236
+ gr.Markdown("Multi-task Audio Generation System based on UniFlow-Audio")
237
+
238
+ with gr.Tabs():
239
+ # Tab 1: Text to Audio
240
+ with gr.Tab("📢 Text to Audio (T2A)"):
241
+ with gr.Row():
242
+ with gr.Column():
243
+ t2a_caption = gr.Textbox(
244
+ label="Audio Caption",
245
+ placeholder="e.g., a man is speaking then a dog barks",
246
+ lines=3
247
+ )
248
+ t2a_model = gr.Dropdown(
249
+ label="Model Name",
250
+ choices=MODEL_CHOICES,
251
+ value=DEFAULT_MODEL
252
+ )
253
+ with gr.Row():
254
+ t2a_guidance = gr.Slider(
255
+ label="Guidance Scale",
256
+ minimum=1.0,
257
+ maximum=10.0,
258
+ value=5.0,
259
+ step=0.5
260
+ )
261
+ t2a_steps = gr.Slider(
262
+ label="Sampling Steps",
263
+ minimum=10,
264
+ maximum=100,
265
+ value=25,
266
+ step=1
267
+ )
268
+ t2a_button = gr.Button("Generate Audio", variant="primary")
269
+
270
+ with gr.Column():
271
+ t2a_output = gr.Audio(
272
+ label="Generated Audio", type="filepath"
273
+ )
274
+ t2a_status = gr.Textbox(label="Status")
275
+
276
+ t2a_button.click(
277
+ fn=text_to_audio,
278
+ inputs=[t2a_caption, t2a_model, t2a_guidance, t2a_steps],
279
+ outputs=[t2a_output, t2a_status]
280
+ )
281
+
282
+ gr.Examples(
283
+ examples=[
284
+ ["a man is speaking then a dog barks", 5.0, 25],
285
+ ["footsteps on wooden floor", 5.0, 25],
286
+ ],
287
+ inputs=[t2a_caption, t2a_guidance, t2a_steps]
288
+ )
289
+
290
+ # Tab 2: Text to Music
291
+ with gr.Tab("🎼 Text to Music (T2M)"):
292
+ with gr.Row():
293
+ with gr.Column():
294
+ t2m_caption = gr.Textbox(
295
+ label="Music Caption",
296
+ placeholder="e.g., pop music with a male singing rap",
297
+ lines=3
298
+ )
299
+ t2m_model = gr.Dropdown(
300
+ label="Model Name",
301
+ choices=MODEL_CHOICES,
302
+ value=DEFAULT_MODEL
303
+ )
304
+ with gr.Row():
305
+ t2m_guidance = gr.Slider(
306
+ label="Guidance Scale",
307
+ minimum=1.0,
308
+ maximum=10.0,
309
+ value=5.0,
310
+ step=0.5
311
+ )
312
+ t2m_steps = gr.Slider(
313
+ label="Sampling Steps",
314
+ minimum=10,
315
+ maximum=100,
316
+ value=25,
317
+ step=1
318
+ )
319
+ t2m_button = gr.Button("Generate Music", variant="primary")
320
+
321
+ with gr.Column():
322
+ t2m_output = gr.Audio(
323
+ label="Generated Music", type="filepath"
324
+ )
325
+ t2m_status = gr.Textbox(label="Status")
326
+
327
+ t2m_button.click(
328
+ fn=text_to_music,
329
+ inputs=[t2m_caption, t2m_model, t2m_guidance, t2m_steps],
330
+ outputs=[t2m_output, t2m_status]
331
+ )
332
+
333
+ gr.Examples(
334
+ examples=[
335
+ ["pop music with a male singing rap", 5.0, 25],
336
+ ["classical piano solo", 5.0, 25],
337
+ ],
338
+ inputs=[t2m_caption, t2m_guidance, t2m_steps]
339
+ )
340
+
341
+ # Tab 3: Text to Speech
342
+ with gr.Tab("🗣️ Text to Speech (TTS)"):
343
+ with gr.Row():
344
+ with gr.Column():
345
+ tts_transcript = gr.Textbox(
346
+ label="Text to Synthesize",
347
+ placeholder="e.g., Hello this is a special sentence",
348
+ lines=3
349
+ )
350
+ tts_ref_audio = gr.Audio(
351
+ label="Reference Speaker Audio", type="filepath"
352
+ )
353
+ tts_model = gr.Dropdown(
354
+ label="Model Name",
355
+ choices=MODEL_CHOICES,
356
+ value=DEFAULT_MODEL
357
+ )
358
+ with gr.Row():
359
+ tts_guidance = gr.Slider(
360
+ label="Guidance Scale",
361
+ minimum=1.0,
362
+ maximum=10.0,
363
+ value=5.0,
364
+ step=0.5
365
+ )
366
+ tts_steps = gr.Slider(
367
+ label="Sampling Steps",
368
+ minimum=10,
369
+ maximum=100,
370
+ value=25,
371
+ step=1
372
+ )
373
+ tts_button = gr.Button(
374
+ "Synthesize Speech", variant="primary"
375
+ )
376
+
377
+ with gr.Column():
378
+ tts_output = gr.Audio(
379
+ label="Synthesized Speech", type="filepath"
380
+ )
381
+ tts_status = gr.Textbox(label="Status")
382
+
383
+ tts_button.click(
384
+ fn=text_to_speech,
385
+ inputs=[
386
+ tts_transcript, tts_ref_audio, tts_model, tts_guidance,
387
+ tts_steps
388
+ ],
389
+ outputs=[tts_output, tts_status]
390
+ )
391
+
392
+ gr.Examples(
393
+ examples=[
394
+ [
395
+ "Hello this is a special sentence with zyloph",
396
+ "./data/egs/tts_speaker_ref.wav", 5.0, 25
397
+ ],
398
+ [
399
+ "The quick brown fox jumps over the lazy dog",
400
+ "./data/egs/tts_speaker_ref.wav", 5.0, 25
401
+ ],
402
+ ],
403
+ inputs=[
404
+ tts_transcript, tts_ref_audio, tts_guidance, tts_steps
405
+ ]
406
+ )
407
+
408
+ # Tab 4: Singing Voice Synthesis
409
+ with gr.Tab("🎤 Singing Voice Synthesis (SVS)"):
410
+ with gr.Row():
411
+ with gr.Column():
412
+ svs_singer = gr.Dropdown(
413
+ label="Singer",
414
+ choices=[
415
+ "Alto-1", "Alto-2", "Alto-3", "Alto-4", "Alto-5",
416
+ "Alto-6", "Alto-7", "Bass-1", "Bass-2", "Bass-3",
417
+ "Soprano-1", "Soprano-2", "Soprano-3", "Tenor-1",
418
+ "Tenor-2", "Tenor-3", "Tenor-4", "Tenor-5",
419
+ "Tenor-6", "Tenor-7"
420
+ ],
421
+ value="Alto-2"
422
+ )
423
+ svs_lyric = gr.Textbox(
424
+ label="Lyrics",
425
+ placeholder="e.g., AP你要相信AP相信我们会像童话故事里AP",
426
+ lines=2
427
+ )
428
+ svs_notes = gr.Textbox(
429
+ label="Note Sequence",
430
+ placeholder="e.g., rest | G#3 | A#3 C4 | D#4 | ...",
431
+ lines=2
432
+ )
433
+ svs_durations = gr.Textbox(
434
+ label="Note Durations",
435
+ placeholder=
436
+ "e.g., 0.14 | 0.47 | 0.1905 0.1895 | 0.41 | ...",
437
+ lines=2
438
+ )
439
+ svs_model = gr.Dropdown(
440
+ label="Model Name",
441
+ choices=MODEL_CHOICES,
442
+ value=DEFAULT_MODEL
443
+ )
444
+ with gr.Row():
445
+ svs_guidance = gr.Slider(
446
+ label="Guidance Scale",
447
+ minimum=1.0,
448
+ maximum=10.0,
449
+ value=5.0,
450
+ step=0.5
451
+ )
452
+ svs_steps = gr.Slider(
453
+ label="Sampling Steps",
454
+ minimum=10,
455
+ maximum=100,
456
+ value=25,
457
+ step=1
458
+ )
459
+ svs_button = gr.Button(
460
+ "Synthesize Singing", variant="primary"
461
+ )
462
+
463
+ with gr.Column():
464
+ svs_output = gr.Audio(
465
+ label="Synthesized Singing", type="filepath"
466
+ )
467
+ svs_status = gr.Textbox(label="Status")
468
+
469
+ svs_button.click(
470
+ fn=singing_voice_synthesis,
471
+ inputs=[
472
+ svs_singer, svs_lyric, svs_notes, svs_durations, svs_model,
473
+ svs_guidance, svs_steps
474
+ ],
475
+ outputs=[svs_output, svs_status]
476
+ )
477
+
478
+ gr.Examples(
479
+ examples=[
480
+ [
481
+ "Alto-2", "AP你要相信AP相信我们会像童话故事里AP",
482
+ "rest | G#3 | A#3 C4 | D#4 | D#4 F4 | rest | E4 F4 | F4 | D#4 A#3 | A#3 | A#3 | C#4 | B3 C4 | C#4 | B3 C4 | A#3 | G#3 | rest",
483
+ "0.14 | 0.47 | 0.1905 0.1895 | 0.41 | 0.3005 0.3895 | 0.21 | 0.2391 0.1809 | 0.32 | 0.4105 0.2095 | 0.35 | 0.43 | 0.45 | 0.2309 0.2291 | 0.48 | 0.225 0.195 | 0.29 | 0.71 | 0.14",
484
+ 5.0, 25
485
+ ],
486
+ ],
487
+ inputs=[
488
+ svs_singer, svs_lyric, svs_notes, svs_durations,
489
+ svs_guidance, svs_steps
490
+ ]
491
+ )
492
+
493
+ gr.Markdown(
494
+ """
495
+ ### Usage Instructions
496
+ - **Lyrics Format**: Use AP for pauses, e.g., `AP你要相信AP相信我们会像童话故事里AP`
497
+ - **Note Format**: Separate with `|`, use spaces for simultaneous notes, use `rest` for rests
498
+ - **Duration Format**: Note durations in seconds, separated by `|`
499
+ """
500
+ )
501
+
502
+ # Tab 5: Speech Enhancement
503
+ with gr.Tab("🔊 Speech Enhancement (SE)"):
504
+ with gr.Row():
505
+ with gr.Column():
506
+ se_input = gr.Audio(label="Noisy Speech", type="filepath")
507
+ se_model = gr.Dropdown(
508
+ label="Model Name",
509
+ choices=MODEL_CHOICES,
510
+ value=DEFAULT_MODEL
511
+ )
512
+ with gr.Row():
513
+ se_guidance = gr.Slider(
514
+ label="Guidance Scale",
515
+ minimum=1.0,
516
+ maximum=10.0,
517
+ value=1.0,
518
+ step=0.5
519
+ )
520
+ se_steps = gr.Slider(
521
+ label="Sampling Steps",
522
+ minimum=10,
523
+ maximum=100,
524
+ value=25,
525
+ step=1
526
+ )
527
+ se_button = gr.Button("Enhance Speech", variant="primary")
528
+
529
+ with gr.Column():
530
+ se_output = gr.Audio(
531
+ label="Enhanced Speech", type="filepath"
532
+ )
533
+ se_status = gr.Textbox(label="Status")
534
+
535
+ se_button.click(
536
+ fn=speech_enhancement,
537
+ inputs=[se_input, se_model, se_guidance, se_steps],
538
+ outputs=[se_output, se_status]
539
+ )
540
+
541
+ gr.Examples(
542
+ examples=[
543
+ ["./data/egs/se_noisy_sample.wav", 1.0, 25],
544
+ ],
545
+ inputs=[se_input, se_guidance, se_steps]
546
+ )
547
+
548
+ # Tab 6: Audio Super Resolution
549
+ with gr.Tab("⬆️ Audio Super Resolution (SR)"):
550
+ with gr.Row():
551
+ with gr.Column():
552
+ sr_input = gr.Audio(
553
+ label="Low Sample Rate Audio", type="filepath"
554
+ )
555
+ sr_model = gr.Dropdown(
556
+ label="Model Name",
557
+ choices=MODEL_CHOICES,
558
+ value=DEFAULT_MODEL
559
+ )
560
+ with gr.Row():
561
+ sr_guidance = gr.Slider(
562
+ label="Guidance Scale",
563
+ minimum=1.0,
564
+ maximum=10.0,
565
+ value=1.0,
566
+ step=0.5
567
+ )
568
+ sr_steps = gr.Slider(
569
+ label="Sampling Steps",
570
+ minimum=10,
571
+ maximum=100,
572
+ value=25,
573
+ step=1
574
+ )
575
+ sr_button = gr.Button(
576
+ "Super-Resolve Audio", variant="primary"
577
+ )
578
+
579
+ with gr.Column():
580
+ sr_output = gr.Audio(
581
+ label="High Sample Rate Audio", type="filepath"
582
+ )
583
+ sr_status = gr.Textbox(label="Status")
584
+
585
+ sr_button.click(
586
+ fn=audio_super_resolution,
587
+ inputs=[sr_input, sr_model, sr_guidance, sr_steps],
588
+ outputs=[sr_output, sr_status]
589
+ )
590
+
591
+ gr.Examples(
592
+ examples=[
593
+ ["./data/egs/sr_low_sr_sample.wav", 1.0, 25],
594
+ ],
595
+ inputs=[sr_input, sr_guidance, sr_steps]
596
+ )
597
+
598
+ # Tab 7: Video to Audio
599
+ with gr.Tab("🎬 Video to Audio (V2A)"):
600
+ with gr.Row():
601
+ with gr.Column():
602
+ v2a_input = gr.Video(label="Input Video")
603
+ v2a_model = gr.Dropdown(
604
+ label="Model Name",
605
+ choices=MODEL_CHOICES,
606
+ value=DEFAULT_MODEL
607
+ )
608
+ with gr.Row():
609
+ v2a_guidance = gr.Slider(
610
+ label="Guidance Scale",
611
+ minimum=1.0,
612
+ maximum=10.0,
613
+ value=5.0,
614
+ step=0.5
615
+ )
616
+ v2a_steps = gr.Slider(
617
+ label="Sampling Steps",
618
+ minimum=10,
619
+ maximum=100,
620
+ value=25,
621
+ step=1
622
+ )
623
+ v2a_button = gr.Button("Generate Audio", variant="primary")
624
+
625
+ with gr.Column():
626
+ v2a_output = gr.Video(label="Video with Audio")
627
+ v2a_status = gr.Textbox(label="Status")
628
+
629
+ v2a_button.click(
630
+ fn=video_to_audio,
631
+ inputs=[v2a_input, v2a_model, v2a_guidance, v2a_steps],
632
+ outputs=[v2a_output, v2a_status]
633
+ )
634
+
635
+ gr.Examples(
636
+ examples=[
637
+ ["./data/egs/v2a_video_sample.mp4", 5.0, 25],
638
+ ],
639
+ inputs=[v2a_input, v2a_guidance, v2a_steps]
640
+ )
641
+
642
+ gr.Markdown(
643
+ """
644
+ ---
645
+ ### 📝 Notes
646
+ - **Model Name**: Choose from `UniFlow-Audio-large`, `UniFlow-Audio-medium`, or `UniFlow-Audio-small`
647
+ - **Guidance Scale**: Controls the guidance strength of the input condition on the output
648
+ - **Sampling Steps**: Number of flow matching sampling steps
649
+
650
+ 💡 Tip: Models will be automatically downloaded on the first run; please be patient
651
+ """
652
+ )
653
+
654
+ if __name__ == "__main__":
655
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
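
For reference, the `<sep>`-joined music score that `singing_voice_synthesis` assembles above can also be built by hand and passed straight to `InferenceCLI.svs`, bypassing the Gradio front end. A minimal sketch using the lyric, note, and duration strings from the SVS example tab; everything here mirrors values already in this file, nothing outside the commit is assumed:

```python
# Minimal sketch: run the SVS task without the Gradio UI.
# The lyric/note/duration strings are the ones from the SVS example above.
import os

from inference_cli import InferenceCLI

cli = InferenceCLI()
lyric = "AP你要相信AP相信我们会像童话故事里AP"
notes = ("rest | G#3 | A#3 C4 | D#4 | D#4 F4 | rest | E4 F4 | F4 | D#4 A#3 | "
         "A#3 | A#3 | C#4 | B3 C4 | C#4 | B3 C4 | A#3 | G#3 | rest")
durations = ("0.14 | 0.47 | 0.1905 0.1895 | 0.41 | 0.3005 0.3895 | 0.21 | "
             "0.2391 0.1809 | 0.32 | 0.4105 0.2095 | 0.35 | 0.43 | 0.45 | "
             "0.2309 0.2291 | 0.48 | 0.225 0.195 | 0.29 | 0.71 | 0.14")

os.makedirs("./outputs", exist_ok=True)
cli.svs(
    singer="Alto-2",
    # same "lyrics<sep>notes<sep>durations" format the app builds
    music_score=f"{lyric}<sep>{notes}<sep>{durations}",
    model_name="UniFlow-Audio-large",  # keyword required: the prehook reads kwargs["model_name"]
    guidance_scale=5.0,
    num_steps=25,
    output_path="./outputs/svs_output.wav",
)
```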
data/egs/se_noisy_sample.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e38f7994348c3af00745273cb55fced24b4508320dc2ed71c61df67607a1880f
3
+ size 316064
data/egs/sr_low_sr_sample.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31d537be6dab1221b453ad5a34d5092452fc48fe6746957405052fc6c99ff9ea
3
+ size 323628
data/egs/tts_speaker_ref.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e5240bf194d56ad0762501f3ef321436ce152b99853fe729047bfc72a5d9563
3
+ size 615884
data/egs/v2a_video_sample.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da83f1e241ba50b97f45e1a869543ed0aa2e862ce583ea77285ea1d94e7a5d0b
3
+ size 711516
inference_cli.py ADDED
@@ -0,0 +1,373 @@
1
+ #!/usr/bin/env python3
2
+
3
+ from typing import Any, Callable
4
+ import json
5
+
6
+ import fire
7
+ import torch
8
+ import torchaudio
9
+ import soundfile as sf
10
+ import numpy as np
11
+
12
+ from modeling_uniflow_audio import UniFlowAudioModel
13
+ from constants import TIME_ALIGNED_TASKS, NON_TIME_ALIGNED_TASKS
14
+
15
+
16
+ class InferenceCLI:
17
+ def __init__(self):
18
+ self.model_name = None
19
+ self.device = torch.device(
20
+ "cuda" if torch.cuda.is_available() else "cpu"
21
+ )
22
+ self.g2p = None
23
+ self.speaker_model = None
24
+ self.svs_processor = None
25
+ self.singer_mapping = None
26
+
27
+ self.video_preprocessor = None
28
+ self.video_size = (256, 256)
29
+ self.video_fps = 10
30
+
31
+ def init_model(self, model_name):
32
+ self.model_name = model_name
33
+ self.model = UniFlowAudioModel(f"wsntxxn/{model_name}")
34
+ self.model.to(self.device)
35
+ self.sample_rate = self.model.config["sample_rate"]
36
+
37
+ def init_speaker_model(self, ):
38
+ import wespeaker
39
+
40
+ if self.speaker_model is None:
41
+ self.speaker_model = wespeaker.load_model("english")
42
+ self.speaker_model.set_device(self.device)
43
+
44
+ def init_svs_processor(self, ):
45
+ from utils.diffsinger_utilities import SVSInputConverter, TokenTextEncoder
46
+
47
+ if self.svs_processor is None:
48
+ phoneme_list = json.load(open(self.model.svs_phone_set_path, "r"))
49
+ self.svs_processor = {
50
+ "converter":
51
+ SVSInputConverter(
52
+ self.model.svs_singer_mapping, self.model.svs_pinyin2ph
53
+ ),
54
+ "tokenizer":
55
+ TokenTextEncoder(
56
+ None, vocab_list=phoneme_list, replace_oov=','
57
+ )
58
+ }
59
+
60
+ def init_video_preprocessor(self, ):
61
+ if self.video_preprocessor is None:
62
+ from transformers import CLIPImageProcessor, CLIPVisionModel
63
+ import torchvision
64
+ self.video_preprocessor = {
65
+ "transform":
66
+ torchvision.transforms.Resize(self.video_size),
67
+ "processor":
68
+ CLIPImageProcessor.
69
+ from_pretrained("openai/clip-vit-large-patch14"),
70
+ "encoder":
71
+ CLIPVisionModel.
72
+ from_pretrained("openai/clip-vit-large-patch14")
73
+ }
74
+ self.video_preprocessor["encoder"].to(self.device)
75
+ self.video_preprocessor["encoder"].eval()
76
+
77
+ def on_inference_start(self, model_name):
78
+ if self.model_name is None or model_name != self.model_name:
79
+ self.init_model(model_name)
80
+
81
+ @staticmethod
82
+ def add_prehook(func: Callable, ):
83
+ def wrapper(self, *args, **kwargs):
84
+ model_name = kwargs["model_name"]
85
+ self.on_inference_start(model_name)
86
+ return func(self, *args, **kwargs)
87
+
88
+ return wrapper
89
+
90
+ @add_prehook
91
+ def t2a(
92
+ self,
93
+ caption: str,
94
+ model_name: str = "UniFlow-Audio-large",
95
+ instruction: str | None = None,
96
+ instruction_idx: int | None = None,
97
+ guidance_scale: float = 5.0,
98
+ num_steps: int = 25,
99
+ output_path: str = "./output.wav",
100
+ ):
101
+ self._run_inference(
102
+ content=caption,
103
+ task="text_to_audio",
104
+ instruction=instruction,
105
+ instruction_idx=instruction_idx,
106
+ model_name=model_name,
107
+ guidance_scale=guidance_scale,
108
+ num_steps=num_steps,
109
+ output_path=output_path
110
+ )
111
+
112
+ @add_prehook
113
+ def t2m(
114
+ self,
115
+ caption: str,
116
+ model_name: str,
117
+ instruction: str | None = None,
118
+ instruction_idx: int | None = None,
119
+ guidance_scale: float = 5.0,
120
+ num_steps: int = 25,
121
+ output_path: str = "./output.wav",
122
+ ):
123
+ self._run_inference(
124
+ content=caption,
125
+ task="text_to_music",
126
+ model_name=model_name,
127
+ instruction=instruction,
128
+ instruction_idx=instruction_idx,
129
+ guidance_scale=guidance_scale,
130
+ num_steps=num_steps,
131
+ output_path=output_path,
132
+ )
133
+
134
+ @add_prehook
135
+ def tts(
136
+ self,
137
+ transcript: str,
138
+ ref_speaker_speech: str,
139
+ model_name: str = "UniFlow-Audio-large",
140
+ instruction: str | None = None,
141
+ instruction_idx: int | None = None,
142
+ guidance_scale: float = 5.0,
143
+ num_steps: int = 25,
144
+ output_path: str = "./output.wav",
145
+ ):
146
+ from g2p_en import G2p
147
+ import nltk
148
+
149
+ self.init_speaker_model()
150
+
151
+ if not self.g2p:
152
+ nltk.download("averaged_perceptron_tagger_eng")
153
+ self.g2p = G2p()
154
+
155
+ phonemes = self.g2p(transcript)
156
+ phone_indices = [
157
+ self.model.tts_phone2id.get(
158
+ p, self.model.tts_phone2id.get("spn", 0)
159
+ ) for p in phonemes
160
+ ]
161
+ xvector = self.speaker_model.extract_embedding(ref_speaker_speech)
162
+
163
+ content = {
164
+ "phoneme": np.array(phone_indices, dtype=np.int64),
165
+ "spk": np.array(xvector, dtype=np.float32),
166
+ }
167
+ self._run_inference(
168
+ content=content,
169
+ task="text_to_speech",
170
+ model_name=model_name,
171
+ instruction=instruction,
172
+ instruction_idx=instruction_idx,
173
+ guidance_scale=guidance_scale,
174
+ num_steps=num_steps,
175
+ output_path=output_path,
176
+ )
177
+
178
+ @add_prehook
179
+ def _audio_input_inference(
180
+ self,
181
+ input_audio: str,
182
+ task: str,
183
+ model_name: str,
184
+ instruction: str | None = None,
185
+ instruction_idx: int | None = None,
186
+ guidance_scale: float = 5.0,
187
+ num_steps: int = 25,
188
+ output_path: str = "./output.wav",
189
+ ):
190
+ waveform, orig_sr = torchaudio.load(input_audio)
191
+ waveform = waveform.mean(0)
192
+ waveform = torchaudio.functional.resample(
193
+ waveform, orig_freq=orig_sr, new_freq=self.sample_rate
194
+ )
195
+ self._run_inference(
196
+ content=waveform,
197
+ task=task,
198
+ instruction=instruction,
199
+ instruction_idx=instruction_idx,
200
+ model_name=model_name,
201
+ guidance_scale=guidance_scale,
202
+ num_steps=num_steps,
203
+ output_path=output_path
204
+ )
205
+
206
+ def se(
207
+ self,
208
+ noisy_speech: str,
209
+ model_name: str,
210
+ instruction: str | None = None,
211
+ instruction_idx: int | None = None,
212
+ guidance_scale: float = 1.0,
213
+ num_steps: int = 25,
214
+ output_path: str = "./output.wav",
215
+ ):
216
+ self._audio_input_inference(
217
+ input_audio=noisy_speech,
218
+ task="speech_enhancement",
219
+ instruction=instruction,
220
+ instruction_idx=instruction_idx,
221
+ model_name=model_name,
222
+ guidance_scale=guidance_scale,
223
+ num_steps=num_steps,
224
+ output_path=output_path
225
+ )
226
+
227
+ def sr(
228
+ self,
229
+ low_sr_audio: str,
230
+ model_name: str,
231
+ instruction: str | None = None,
232
+ instruction_idx: int | None = None,
233
+ guidance_scale: float = 1.0,
234
+ num_steps: int = 25,
235
+ output_path: str = "./output.wav",
236
+ ):
237
+ self._audio_input_inference(
238
+ input_audio=low_sr_audio,
239
+ task="audio_super_resolution",
240
+ instruction=instruction,
241
+ instruction_idx=instruction_idx,
242
+ model_name=model_name,
243
+ guidance_scale=guidance_scale,
244
+ num_steps=num_steps,
245
+ output_path=output_path
246
+ )
247
+
248
+ @add_prehook
249
+ def v2a(
250
+ self,
251
+ video: str,
252
+ model_name: str,
253
+ instruction: str | None = None,
254
+ instruction_idx: int | None = None,
255
+ guidance_scale: float = 5.0,
256
+ num_steps: int = 25,
257
+ output_path: str = "./output.mp4",
258
+ ):
259
+ from utils.video import read_video_frames, merge_audio_video
260
+
261
+ self.init_video_preprocessor()
262
+ video_path = video
263
+ video = read_video_frames(
264
+ video,
265
+ duration=None,
266
+ fps=self.video_fps,
267
+ video_size=self.video_size,
268
+ resize_transform=self.video_preprocessor["transform"]
269
+ )
270
+ pixel_values = self.video_preprocessor["processor"](
271
+ images=video, return_tensors="pt"
272
+ ).pixel_values.to(self.device)
273
+
274
+ with torch.no_grad():
275
+ output = self.video_preprocessor["encoder"](pixel_values)
276
+ video_feature = output.pooler_output
277
+
278
+ waveform = self._run_inference(
279
+ content=video_feature,
280
+ task="video_to_audio",
281
+ model_name=model_name,
282
+ instruction=instruction,
283
+ instruction_idx=instruction_idx,
284
+ guidance_scale=guidance_scale,
285
+ num_steps=num_steps,
286
+ output_path=output_path,
287
+ )
288
+
289
+ merge_audio_video(
290
+ waveform, video_path, output_path, audio_fps=self.sample_rate
291
+ )
292
+
293
+ @add_prehook
294
+ def svs(
295
+ self,
296
+ singer: str,
297
+ music_score: str,
298
+ model_name: str,
299
+ instruction: str | None = None,
300
+ instruction_idx: int | None = None,
301
+ guidance_scale: float = 5.0,
302
+ num_steps: int = 25,
303
+ output_path: str = "./output.wav",
304
+ ):
305
+ self.init_svs_processor()
306
+ text, note, note_dur = music_score.split('<sep>')
307
+ if singer not in self.model.svs_singer_mapping:
308
+ print(f"Unsupported singer {singer}, available singers: ")
309
+ print(list(self.model.svs_singer_mapping.keys()))
310
+ raise KeyError
311
+
312
+ midi = self.svs_processor["converter"].preprocess_input({
313
+ "spk_name": singer,
314
+ "text": text,
315
+ "notes": note,
316
+ "notes_duration": note_dur,
317
+ })
318
+ midi["phoneme"] = self.svs_processor["tokenizer"].encode(
319
+ midi["phoneme"]
320
+ )
321
+ self._run_inference(
322
+ content=midi,
323
+ task="singing_voice_synthesis",
324
+ model_name=model_name,
325
+ instruction=instruction,
326
+ instruction_idx=instruction_idx,
327
+ guidance_scale=guidance_scale,
328
+ num_steps=num_steps,
329
+ output_path=output_path,
330
+ )
331
+
332
+ def _run_inference(
333
+ self,
334
+ content: Any,
335
+ task: str,
336
+ model_name: str,
337
+ instruction: str | None = None,
338
+ instruction_idx: int | None = None,
339
+ guidance_scale: float = 5.0,
340
+ num_steps: int = 25,
341
+ output_path: str = "./output.wav",
342
+ ):
343
+ if self.model_name is None or model_name != self.model_name:
344
+ self.init_model(model_name)
345
+ if task in TIME_ALIGNED_TASKS:
346
+ is_time_aligned = True
347
+ else:
348
+ is_time_aligned = False
349
+ if instruction:
350
+ instruction = [instruction]
351
+ if instruction_idx:
352
+ instruction_idx = [instruction_idx]
353
+
354
+ waveform = self.model.sample(
355
+ content=[content],
356
+ task=[task],
357
+ is_time_aligned=[is_time_aligned],
358
+ instruction=instruction,
359
+ instruction_idx=instruction_idx,
360
+ num_steps=num_steps,
361
+ guidance_scale=guidance_scale,
362
+ disable_progress=False
363
+ )
364
+ waveform = waveform[0, 0].cpu().numpy()
365
+
366
+ if not output_path.endswith(".mp4"):
367
+ sf.write(output_path, waveform, self.sample_rate)
368
+
369
+ return waveform
370
+
371
+
372
+ if __name__ == "__main__":
373
+ fire.Fire(InferenceCLI)
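
Because the module ends with `fire.Fire(InferenceCLI)`, every public method above also works as a command-line entry point (for example `python inference_cli.py t2a --caption "a dog barks" --model_name UniFlow-Audio-large`). The class can equally be driven from Python, which is what `app.py` does; a minimal sketch, reusing the sample files shipped under `data/egs/` in this commit:

```python
# Minimal programmatic sketch mirroring how app.py calls InferenceCLI.
from inference_cli import InferenceCLI

cli = InferenceCLI()

# Text-to-audio. model_name must be passed as a keyword: the add_prehook
# wrapper looks it up in kwargs to decide whether to (re)load the checkpoint.
cli.t2a(
    caption="a man is speaking then a dog barks",
    model_name="UniFlow-Audio-large",
    guidance_scale=5.0,
    num_steps=25,
    output_path="./t2a_output.wav",
)

# Speech enhancement on the bundled noisy sample; the already-initialized
# model is reused because the model name is unchanged.
cli.se(
    noisy_speech="./data/egs/se_noisy_sample.wav",
    model_name="UniFlow-Audio-large",
    guidance_scale=1.0,
    num_steps=25,
    output_path="./se_enhanced.wav",
)
```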
modeling_uniflow_audio.py ADDED
@@ -0,0 +1,134 @@
1
+ from typing import Any, Sequence
2
+ from pathlib import Path
3
+ import json
4
+ import shutil
5
+
6
+ import h5py
7
+ from huggingface_hub import snapshot_download
8
+ from omegaconf import OmegaConf
9
+ from safetensors.torch import load_file
10
+ import hydra
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn.utils.rnn import pad_sequence
14
+ from transformers import T5EncoderModel, T5Tokenizer
15
+
16
+
17
+ class UniFlowAudioModel(nn.Module):
18
+ def __init__(self, model_name: str = "wsntxxn/UniFlow-Audio-large"):
19
+ assert model_name in (
20
+ "wsntxxn/UniFlow-Audio-large",
21
+ "wsntxxn/UniFlow-Audio-medium",
22
+ "wsntxxn/UniFlow-Audio-small",
23
+ )
24
+ super().__init__()
25
+ model_dir = snapshot_download(repo_id=model_name)
26
+ model_dir = Path(model_dir)
27
+ self.config = OmegaConf.load(model_dir / "config.yaml")
28
+ self.config["model"]["autoencoder"]["pretrained_ckpt"] = str(
29
+ model_dir / self.config["model"]["autoencoder"]["pretrained_ckpt"]
30
+ )
31
+ self.model = hydra.utils.instantiate(
32
+ self.config["model"], _convert_="all"
33
+ )
34
+ state_dict = load_file(model_dir / "model.safetensors")
35
+ self.model.load_pretrained(state_dict)
36
+ self.model.eval()
37
+
38
+ self.g2p_model_path = model_dir / "mfa_g2p" / "english_us_arpa_unhashed.zip"
39
+ if not self.g2p_model_path.exists():
40
+ ori_model_path = (model_dir / "mfa_g2p" /
41
+ "english_us_arpa.zip").resolve()
42
+ shutil.copy(ori_model_path, self.g2p_model_path)
43
+
44
+ self.tts_phone_set_path = model_dir / "mfa_g2p" / "phone_set.json"
45
+ self.build_tts_phone_mapping()
46
+ self.svs_phone_set_path = model_dir / "svs" / "phone_set.json"
47
+ singers = json.load(open(model_dir / "svs" / "spk_set.json", "r"))
48
+ self.svs_singer_mapping = {
49
+ singer: i
50
+ for i, singer in enumerate(singers)
51
+ }
52
+ self.svs_pinyin2ph = model_dir / "svs" / "m4singer_pinyin2ph.txt"
53
+
54
+ self.task_to_instructions = {}
55
+ with h5py.File(model_dir / "instructions" / "t5_embeddings.h5") as hf:
56
+ for key in hf.keys():
57
+ self.task_to_instructions[key] = hf[key][()]
58
+
59
+ self.init_instruction_encoder()
60
+
61
+ def build_tts_phone_mapping(self):
62
+ with open(self.tts_phone_set_path, "r", encoding="utf-8") as f:
63
+ phone_set = json.load(f)
64
+
65
+ self.tts_phone2id = {p: i for i, p in enumerate(phone_set)}
66
+
67
+ def init_instruction_encoder(self):
68
+ self.instruction_tokenizer = T5Tokenizer.from_pretrained(
69
+ "google/flan-t5-large"
70
+ )
71
+ self.instruction_encoder = T5EncoderModel.from_pretrained(
72
+ "google/flan-t5-large"
73
+ )
74
+ self.instruction_encoder.eval()
75
+
76
+ @torch.inference_mode()
77
+ def encode_instruction(self, instruction: list[str], device: torch.device):
78
+ with torch.amp.autocast(device_type=device.type, enabled=False):
79
+ tokens = self.instruction_tokenizer(
80
+ instruction,
81
+ max_length=self.instruction_tokenizer.model_max_length,
82
+ padding=True,
83
+ truncation=True,
84
+ return_tensors="pt",
85
+ )
86
+ input_ids = tokens.input_ids.to(device)
87
+ attention_mask = tokens.attention_mask.to(device)
88
+ output = self.instruction_encoder(
89
+ input_ids=input_ids, attention_mask=attention_mask
90
+ )
91
+ output = output.last_hidden_state
92
+ length = attention_mask.sum(dim=1)
93
+ return output, length
94
+
95
+ @torch.inference_mode()
96
+ def sample(
97
+ self,
98
+ content: list[Any],
99
+ task: list[str],
100
+ is_time_aligned: Sequence[bool],
101
+ instruction: list[str] | None = None,
102
+ instruction_idx: list[int] | None = None,
103
+ num_steps: int = 20,
104
+ sway_sampling_coef: float | None = -1.0,
105
+ guidance_scale: float = 3.0,
106
+ disable_progress: bool = True,
107
+ ):
108
+ device = self.model.dummy_param.device
109
+
110
+ if instruction is None:
111
+ instructions = []
112
+ instruction_lengths = []
113
+ for sample_idx, task_ in enumerate(task):
114
+ if instruction_idx:
115
+ instruction_idx_ = instruction_idx[sample_idx]
116
+ else:
117
+ instruction_idx_ = 0
118
+ instruction_ = self.task_to_instructions[
119
+ f"{task_}_{instruction_idx_}"]
120
+ instructions.append(torch.as_tensor(instruction_))
121
+ instruction_lengths.append(instruction_.shape[0])
122
+ instructions = pad_sequence(instructions,
123
+ batch_first=True).to(device)
124
+ instruction_lengths = torch.as_tensor(instruction_lengths
125
+ ).to(device)
126
+ else:
127
+ instructions, instruction_lengths = self.encode_instruction(
128
+ instruction, device
129
+ )
130
+
131
+ return self.model.inference(
132
+ content, task, is_time_aligned, instructions, instruction_lengths,
133
+ num_steps, sway_sampling_coef, guidance_scale, disable_progress
134
+ )
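
`UniFlowAudioModel.sample` can also be called without the CLI layer; `_run_inference` in `inference_cli.py` above shows the expected batch-of-one call pattern. A minimal sketch for text-to-audio, with the time-alignment flag looked up through the same `constants` module the CLI imports (that module is not part of this commit, so the lookup is assumed to behave as it does there):

```python
# Minimal sketch: drive UniFlowAudioModel.sample directly, following the
# batch-of-one layout used by InferenceCLI._run_inference.
import soundfile as sf
import torch

from constants import TIME_ALIGNED_TASKS  # same lookup the CLI performs
from modeling_uniflow_audio import UniFlowAudioModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UniFlowAudioModel("wsntxxn/UniFlow-Audio-large").to(device)

task = "text_to_audio"
waveform = model.sample(
    content=["a man is speaking then a dog barks"],
    task=[task],
    is_time_aligned=[task in TIME_ALIGNED_TASKS],
    num_steps=25,
    guidance_scale=5.0,
    disable_progress=False,
)
# As in _run_inference above: take the first item and channel, then write it.
sf.write(
    "./t2a_direct.wav", waveform[0, 0].cpu().numpy(),
    model.config["sample_rate"]
)
```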
models/autoencoder/autoencoder_base.py ADDED
@@ -0,0 +1,22 @@
1
+ from abc import abstractmethod, ABC
2
+ from typing import Sequence
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class AutoEncoderBase(ABC):
8
+ def __init__(
9
+ self, downsampling_ratio: int, sample_rate: int,
10
+ latent_shape: Sequence[int | None]
11
+ ):
12
+ self.downsampling_ratio = downsampling_ratio
13
+ self.sample_rate = sample_rate
14
+ self.latent_token_rate = sample_rate // downsampling_ratio
15
+ self.latent_shape = latent_shape
16
+ self.time_dim = latent_shape.index(None) + 1 # the first dim is batch
17
+
18
+ @abstractmethod
19
+ def encode(
20
+ self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
21
+ ) -> tuple[torch.Tensor, torch.Tensor]:
22
+ ...
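
The abstract base above only fixes bookkeeping: `latent_shape` marks the variable-length time axis with `None`, and `time_dim` resolves to that axis once the batch dimension is prepended (for `latent_shape=(latent_dim, None)` it comes out as 2, i.e. (batch, channel, time)). A minimal sketch of a concrete subclass with toy shapes chosen purely for illustration, making no claim about the real codecs in this repo; it follows the same explicit double-`__init__` pattern `StableVAE` uses below:

```python
# Minimal sketch of an AutoEncoderBase subclass; the strided conv stands in
# for a real codec and is only here to show the expected shapes.
import torch
import torch.nn as nn

from models.autoencoder.autoencoder_base import AutoEncoderBase


class ToyAutoEncoder(nn.Module, AutoEncoderBase):
    def __init__(self):
        nn.Module.__init__(self)
        # latent_shape=(64, None): 64 latent channels, variable time axis,
        # so time_dim == 2 for (batch, channel, time) latents.
        AutoEncoderBase.__init__(
            self, downsampling_ratio=320, sample_rate=16000,
            latent_shape=(64, None)
        )
        self.proj = nn.Conv1d(1, 64, kernel_size=320, stride=320)

    def encode(
        self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        z = self.proj(waveform.unsqueeze(1))  # (batch, 64, ~T / 320)
        z_lengths = waveform_lengths // self.downsampling_ratio
        return z, z_lengths
```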
models/autoencoder/waveform/dac.py ADDED
File without changes
models/autoencoder/waveform/stable_vae.py ADDED
@@ -0,0 +1,559 @@
1
+ from typing import Any, Literal, Callable
2
+ import math
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.utils import weight_norm
8
+ import torchaudio
9
+ from alias_free_torch import Activation1d
10
+
11
+ from models.common import LoadPretrainedBase
12
+ from models.autoencoder.autoencoder_base import AutoEncoderBase
13
+ from utils.torch_utilities import remove_key_prefix_factory, create_mask_from_length
14
+
15
+
16
+ # jit script make it 1.4x faster and save GPU memory
17
+ @torch.jit.script
18
+ def snake_beta(x, alpha, beta):
19
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
20
+
21
+
22
+ class SnakeBeta(nn.Module):
23
+ def __init__(
24
+ self,
25
+ in_features,
26
+ alpha=1.0,
27
+ alpha_trainable=True,
28
+ alpha_logscale=True
29
+ ):
30
+ super(SnakeBeta, self).__init__()
31
+ self.in_features = in_features
32
+
33
+ # initialize alpha
34
+ self.alpha_logscale = alpha_logscale
35
+ if self.alpha_logscale:
36
+ # log scale alphas initialized to zeros
37
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
38
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
39
+ else:
40
+ # linear scale alphas initialized to ones
41
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
42
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
43
+
44
+ self.alpha.requires_grad = alpha_trainable
45
+ self.beta.requires_grad = alpha_trainable
46
+
47
+ # self.no_div_by_zero = 0.000000001
48
+
49
+ def forward(self, x):
50
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
51
+ # line up with x to [B, C, T]
52
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
53
+ if self.alpha_logscale:
54
+ alpha = torch.exp(alpha)
55
+ beta = torch.exp(beta)
56
+ x = snake_beta(x, alpha, beta)
57
+
58
+ return x
59
+
60
+
61
+ def WNConv1d(*args, **kwargs):
62
+ return weight_norm(nn.Conv1d(*args, **kwargs))
63
+
64
+
65
+ def WNConvTranspose1d(*args, **kwargs):
66
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
67
+
68
+
69
+ def get_activation(
70
+ activation: Literal["elu", "snake", "none"],
71
+ antialias=False,
72
+ channels=None
73
+ ) -> nn.Module:
74
+ if activation == "elu":
75
+ act = nn.ELU()
76
+ elif activation == "snake":
77
+ act = SnakeBeta(channels)
78
+ elif activation == "none":
79
+ act = nn.Identity()
80
+ else:
81
+ raise ValueError(f"Unknown activation {activation}")
82
+
83
+ if antialias:
84
+ act = Activation1d(act)
85
+
86
+ return act
87
+
88
+
89
+ class ResidualUnit(nn.Module):
90
+ def __init__(
91
+ self,
92
+ in_channels,
93
+ out_channels,
94
+ dilation,
95
+ use_snake=False,
96
+ antialias_activation=False
97
+ ):
98
+ super().__init__()
99
+
100
+ self.dilation = dilation
101
+
102
+ padding = (dilation * (7 - 1)) // 2
103
+
104
+ self.layers = nn.Sequential(
105
+ get_activation(
106
+ "snake" if use_snake else "elu",
107
+ antialias=antialias_activation,
108
+ channels=out_channels
109
+ ),
110
+ WNConv1d(
111
+ in_channels=in_channels,
112
+ out_channels=out_channels,
113
+ kernel_size=7,
114
+ dilation=dilation,
115
+ padding=padding
116
+ ),
117
+ get_activation(
118
+ "snake" if use_snake else "elu",
119
+ antialias=antialias_activation,
120
+ channels=out_channels
121
+ ),
122
+ WNConv1d(
123
+ in_channels=out_channels,
124
+ out_channels=out_channels,
125
+ kernel_size=1
126
+ )
127
+ )
128
+
129
+ def forward(self, x):
130
+ res = x
131
+
132
+ #x = checkpoint(self.layers, x)
133
+ x = self.layers(x)
134
+
135
+ return x + res
136
+
137
+
138
+ class EncoderBlock(nn.Module):
139
+ def __init__(
140
+ self,
141
+ in_channels,
142
+ out_channels,
143
+ stride,
144
+ use_snake=False,
145
+ antialias_activation=False
146
+ ):
147
+ super().__init__()
148
+
149
+ self.layers = nn.Sequential(
150
+ ResidualUnit(
151
+ in_channels=in_channels,
152
+ out_channels=in_channels,
153
+ dilation=1,
154
+ use_snake=use_snake
155
+ ),
156
+ ResidualUnit(
157
+ in_channels=in_channels,
158
+ out_channels=in_channels,
159
+ dilation=3,
160
+ use_snake=use_snake
161
+ ),
162
+ ResidualUnit(
163
+ in_channels=in_channels,
164
+ out_channels=in_channels,
165
+ dilation=9,
166
+ use_snake=use_snake
167
+ ),
168
+ get_activation(
169
+ "snake" if use_snake else "elu",
170
+ antialias=antialias_activation,
171
+ channels=in_channels
172
+ ),
173
+ WNConv1d(
174
+ in_channels=in_channels,
175
+ out_channels=out_channels,
176
+ kernel_size=2 * stride,
177
+ stride=stride,
178
+ padding=math.ceil(stride / 2)
179
+ ),
180
+ )
181
+
182
+ def forward(self, x):
183
+ return self.layers(x)
184
+
185
+
186
+ class DecoderBlock(nn.Module):
187
+ def __init__(
188
+ self,
189
+ in_channels,
190
+ out_channels,
191
+ stride,
192
+ use_snake=False,
193
+ antialias_activation=False,
194
+ use_nearest_upsample=False
195
+ ):
196
+ super().__init__()
197
+
198
+ if use_nearest_upsample:
199
+ upsample_layer = nn.Sequential(
200
+ nn.Upsample(scale_factor=stride, mode="nearest"),
201
+ WNConv1d(
202
+ in_channels=in_channels,
203
+ out_channels=out_channels,
204
+ kernel_size=2 * stride,
205
+ stride=1,
206
+ bias=False,
207
+ padding='same'
208
+ )
209
+ )
210
+ else:
211
+ upsample_layer = WNConvTranspose1d(
212
+ in_channels=in_channels,
213
+ out_channels=out_channels,
214
+ kernel_size=2 * stride,
215
+ stride=stride,
216
+ padding=math.ceil(stride / 2)
217
+ )
218
+
219
+ self.layers = nn.Sequential(
220
+ get_activation(
221
+ "snake" if use_snake else "elu",
222
+ antialias=antialias_activation,
223
+ channels=in_channels
224
+ ),
225
+ upsample_layer,
226
+ ResidualUnit(
227
+ in_channels=out_channels,
228
+ out_channels=out_channels,
229
+ dilation=1,
230
+ use_snake=use_snake
231
+ ),
232
+ ResidualUnit(
233
+ in_channels=out_channels,
234
+ out_channels=out_channels,
235
+ dilation=3,
236
+ use_snake=use_snake
237
+ ),
238
+ ResidualUnit(
239
+ in_channels=out_channels,
240
+ out_channels=out_channels,
241
+ dilation=9,
242
+ use_snake=use_snake
243
+ ),
244
+ )
245
+
246
+ def forward(self, x):
247
+ return self.layers(x)
248
+
249
+
250
+ class OobleckEncoder(nn.Module):
251
+ def __init__(
252
+ self,
253
+ in_channels=2,
254
+ channels=128,
255
+ latent_dim=32,
256
+ c_mults=[1, 2, 4, 8],
257
+ strides=[2, 4, 8, 8],
258
+ use_snake=False,
259
+ antialias_activation=False
260
+ ):
261
+ super().__init__()
262
+
263
+ c_mults = [1] + c_mults
264
+
265
+ self.depth = len(c_mults)
266
+
267
+ layers = [
268
+ WNConv1d(
269
+ in_channels=in_channels,
270
+ out_channels=c_mults[0] * channels,
271
+ kernel_size=7,
272
+ padding=3
273
+ )
274
+ ]
275
+
276
+ for i in range(self.depth - 1):
277
+ layers += [
278
+ EncoderBlock(
279
+ in_channels=c_mults[i] * channels,
280
+ out_channels=c_mults[i + 1] * channels,
281
+ stride=strides[i],
282
+ use_snake=use_snake
283
+ )
284
+ ]
285
+
286
+ layers += [
287
+ get_activation(
288
+ "snake" if use_snake else "elu",
289
+ antialias=antialias_activation,
290
+ channels=c_mults[-1] * channels
291
+ ),
292
+ WNConv1d(
293
+ in_channels=c_mults[-1] * channels,
294
+ out_channels=latent_dim,
295
+ kernel_size=3,
296
+ padding=1
297
+ )
298
+ ]
299
+
300
+ self.layers = nn.Sequential(*layers)
301
+
302
+ def forward(self, x):
303
+ return self.layers(x)
304
+
305
+
306
+ class OobleckDecoder(nn.Module):
307
+ def __init__(
308
+ self,
309
+ out_channels=2,
310
+ channels=128,
311
+ latent_dim=32,
312
+ c_mults=[1, 2, 4, 8],
313
+ strides=[2, 4, 8, 8],
314
+ use_snake=False,
315
+ antialias_activation=False,
316
+ use_nearest_upsample=False,
317
+ final_tanh=True
318
+ ):
319
+ super().__init__()
320
+
321
+ c_mults = [1] + c_mults
322
+
323
+ self.depth = len(c_mults)
324
+
325
+ layers = [
326
+ WNConv1d(
327
+ in_channels=latent_dim,
328
+ out_channels=c_mults[-1] * channels,
329
+ kernel_size=7,
330
+ padding=3
331
+ ),
332
+ ]
333
+
334
+ for i in range(self.depth - 1, 0, -1):
335
+ layers += [
336
+ DecoderBlock(
337
+ in_channels=c_mults[i] * channels,
338
+ out_channels=c_mults[i - 1] * channels,
339
+ stride=strides[i - 1],
340
+ use_snake=use_snake,
341
+ antialias_activation=antialias_activation,
342
+ use_nearest_upsample=use_nearest_upsample
343
+ )
344
+ ]
345
+
346
+ layers += [
347
+ get_activation(
348
+ "snake" if use_snake else "elu",
349
+ antialias=antialias_activation,
350
+ channels=c_mults[0] * channels
351
+ ),
352
+ WNConv1d(
353
+ in_channels=c_mults[0] * channels,
354
+ out_channels=out_channels,
355
+ kernel_size=7,
356
+ padding=3,
357
+ bias=False
358
+ ),
359
+ nn.Tanh() if final_tanh else nn.Identity()
360
+ ]
361
+
362
+ self.layers = nn.Sequential(*layers)
363
+
364
+ def forward(self, x):
365
+ return self.layers(x)
366
+
367
+
368
+ class Bottleneck(nn.Module):
369
+ def __init__(self, is_discrete: bool = False):
370
+ super().__init__()
371
+
372
+ self.is_discrete = is_discrete
373
+
374
+ def encode(self, x, return_info=False, **kwargs):
375
+ raise NotImplementedError
376
+
377
+ def decode(self, x):
378
+ raise NotImplementedError
379
+
380
+
381
+ @torch.jit.script
382
+ def vae_sample(mean, scale) -> dict[str, torch.Tensor]:
383
+ stdev = nn.functional.softplus(scale) + 1e-4
384
+ var = stdev * stdev
385
+ logvar = torch.log(var)
386
+ latents = torch.randn_like(mean) * stdev + mean
387
+
388
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
389
+ return {"latents": latents, "kl": kl}
390
+
391
+
392
+ class VAEBottleneck(Bottleneck):
393
+ def __init__(self):
394
+ super().__init__(is_discrete=False)
395
+
396
+ def encode(self,
397
+ x,
398
+ return_info=False,
399
+ **kwargs) -> dict[str, torch.Tensor] | torch.Tensor:
400
+ mean, scale = x.chunk(2, dim=1)
401
+ sampled = vae_sample(mean, scale)
402
+
403
+ if return_info:
404
+ return sampled["latents"], {"kl": sampled["kl"]}
405
+ else:
406
+ return sampled["latents"]
407
+
408
+ def decode(self, x):
409
+ return x
410
+
411
+
412
+ def compute_mean_kernel(x, y):
413
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
414
+ return torch.exp(-kernel_input).mean()
415
+
416
+
417
+ class Pretransform(nn.Module):
418
+ def __init__(self, enable_grad, io_channels, is_discrete):
419
+ super().__init__()
420
+
421
+ self.is_discrete = is_discrete
422
+ self.io_channels = io_channels
423
+ self.encoded_channels = None
424
+ self.downsampling_ratio = None
425
+
426
+ self.enable_grad = enable_grad
427
+
428
+ def encode(self, x):
429
+ raise NotImplementedError
430
+
431
+ def decode(self, z):
432
+ raise NotImplementedError
433
+
434
+ def tokenize(self, x):
435
+ raise NotImplementedError
436
+
437
+ def decode_tokens(self, tokens):
438
+ raise NotImplementedError
439
+
440
+
441
+ class StableVAE(LoadPretrainedBase, AutoEncoderBase):
442
+ def __init__(
443
+ self,
444
+ encoder,
445
+ decoder,
446
+ latent_dim,
447
+ downsampling_ratio,
448
+ sample_rate,
449
+ io_channels=2,
450
+ bottleneck: Bottleneck = None,
451
+ pretransform: Pretransform = None,
452
+ in_channels=None,
453
+ out_channels=None,
454
+ soft_clip=False,
455
+ pretrained_ckpt: str | Path = None
456
+ ):
457
+ LoadPretrainedBase.__init__(self)
458
+ AutoEncoderBase.__init__(
459
+ self,
460
+ downsampling_ratio=downsampling_ratio,
461
+ sample_rate=sample_rate,
462
+ latent_shape=(latent_dim, None)
463
+ )
464
+
465
+ self.latent_dim = latent_dim
466
+ self.io_channels = io_channels
467
+ self.in_channels = io_channels
468
+ self.out_channels = io_channels
469
+ self.min_length = self.downsampling_ratio
470
+
471
+ if in_channels is not None:
472
+ self.in_channels = in_channels
473
+
474
+ if out_channels is not None:
475
+ self.out_channels = out_channels
476
+
477
+ self.bottleneck = bottleneck
478
+ self.encoder = encoder
479
+ self.decoder = decoder
480
+ self.pretransform = pretransform
481
+ self.soft_clip = soft_clip
482
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
483
+
484
+ self.remove_autoencoder_prefix_fn: Callable = remove_key_prefix_factory(
485
+ "autoencoder."
486
+ )
487
+ if pretrained_ckpt is not None:
488
+ self.load_pretrained(pretrained_ckpt)
489
+
490
+ def process_state_dict(self, model_dict, state_dict):
491
+ state_dict = state_dict["state_dict"]
492
+ state_dict = self.remove_autoencoder_prefix_fn(model_dict, state_dict)
493
+ return state_dict
494
+
495
+ def encode(
496
+ self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
497
+ ) -> tuple[torch.Tensor, torch.Tensor]:
498
+ z = self.encoder(waveform)
499
+ z = self.bottleneck.encode(z)
500
+ z_length = waveform_lengths // self.downsampling_ratio
501
+ z_mask = create_mask_from_length(z_length)
502
+ return z, z_mask
503
+
504
+ def decode(self, latents: torch.Tensor) -> torch.Tensor:
505
+ waveform = self.decoder(latents)
506
+ return waveform
507
+
508
+
509
+ class StableVAEProjectorWrapper(nn.Module):
510
+ def __init__(
511
+ self,
512
+ vae_dim: int,
513
+ embed_dim: int,
514
+ model: StableVAE | None = None,
515
+ ):
516
+ super().__init__()
517
+ self.model = model
518
+ self.proj = nn.Linear(vae_dim, embed_dim)
519
+
520
+ def forward(
521
+ self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
522
+ ) -> tuple[torch.Tensor, torch.Tensor]:
523
+ self.model.eval()
524
+ with torch.no_grad():
525
+ z, z_mask = self.model.encode(waveform, waveform_lengths)
526
+ z = self.proj(z.transpose(1, 2))
527
+ return {"output": z, "mask": z_mask}
528
+
529
+
530
+ if __name__ == '__main__':
531
+ import hydra
532
+ from utils.config import generate_config_from_command_line_overrides
533
+ model_config = generate_config_from_command_line_overrides(
534
+ "configs/model/autoencoder/stable_vae.yaml"
535
+ )
536
+ autoencoder: StableVAE = hydra.utils.instantiate(model_config)
537
+ autoencoder.eval()
538
+
539
+ waveform, sr = torchaudio.load(
540
+ "/hpc_stor03/sjtu_home/xuenan.xu/data/m4singer/Tenor-1#童话/0006.wav"
541
+ )
542
+ waveform = waveform.mean(0, keepdim=True)
543
+ waveform = torchaudio.functional.resample(
544
+ waveform, sr, model_config["sample_rate"]
545
+ )
546
+ print("waveform: ", waveform.shape)
547
+ with torch.no_grad():
548
+ latent, latent_length = autoencoder.encode(
549
+ waveform, torch.as_tensor([waveform.shape[-1]])
550
+ )
551
+ print("latent: ", latent.shape)
552
+ reconstructed = autoencoder.decode(latent)
553
+ print("reconstructed: ", reconstructed.shape)
554
+ import soundfile as sf
555
+ sf.write(
556
+ "./reconstructed.wav",
557
+ reconstructed[0, 0].numpy(),
558
+ samplerate=model_config["sample_rate"]
559
+ )
models/common.py ADDED
@@ -0,0 +1,167 @@
1
+ from pathlib import Path
2
+ from typing import Sequence
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from utils.torch_utilities import (
8
+ load_pretrained_model, merge_matched_keys, create_mask_from_length,
9
+ loss_with_mask, create_alignment_path
10
+ )
11
+
12
+
13
+ class LoadPretrainedBase(nn.Module):
14
+ def process_state_dict(
15
+ self, model_dict: dict[str, torch.Tensor],
16
+ state_dict: dict[str, torch.Tensor]
17
+ ):
18
+ """
19
+ Custom per-model processing function that transforms the `state_dict` loaded from
20
+ checkpoints to the state that can be used in `load_state_dict`.
21
+ Use `merge_matched_keys` to update parameters with matched names and shapes by
22
+ default.
23
+
24
+ Args:
25
+ model_dict:
26
+ The state dict of the current model, which is going to load pretrained parameters
27
+ state_dict:
28
+ A dictionary of parameters from a pre-trained model.
29
+
30
+ Returns:
31
+ dict[str, torch.Tensor]:
32
+ The updated state dict, where parameters with matched keys and shape are
33
+ updated with values in `state_dict`.
34
+ """
35
+ state_dict = merge_matched_keys(model_dict, state_dict)
36
+ return state_dict
37
+
38
+ def load_pretrained(self, ckpt_path: str | Path):
39
+ load_pretrained_model(
40
+ self, ckpt_path, state_dict_process_fn=self.process_state_dict
41
+ )
42
+
43
+
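# Illustrative sketch, not part of the committed file: subclasses hook into
# load_pretrained() by overriding process_state_dict() to massage checkpoint
# keys before the default matched-key merge. StableVAE, for instance, unwraps
# "state_dict" and strips an "autoencoder." prefix. The class below is
# hypothetical.
class WrappedCheckpointModel(LoadPretrainedBase):
    def process_state_dict(self, model_dict, state_dict):
        # Unwrap a trainer-style checkpoint if present, then fall back to the
        # default merge of parameters with matching names and shapes.
        state_dict = state_dict.get("state_dict", state_dict)
        return super().process_state_dict(model_dict, state_dict)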
44
+ class CountParamsBase(nn.Module):
45
+ def count_params(self):
46
+ num_params = 0
47
+ trainable_params = 0
48
+ for param in self.parameters():
49
+ num_params += param.numel()
50
+ if param.requires_grad:
51
+ trainable_params += param.numel()
52
+ return num_params, trainable_params
53
+
54
+
55
+ class SaveTrainableParamsBase(nn.Module):
56
+ @property
57
+ def param_names_to_save(self):
58
+ names = []
59
+ for name, param in self.named_parameters():
60
+ if param.requires_grad:
61
+ names.append(name)
62
+ for name, _ in self.named_buffers():
63
+ names.append(name)
64
+ return names
65
+
66
+ def load_state_dict(self, state_dict, strict=True):
67
+ for key in self.param_names_to_save:
68
+ if key not in state_dict:
69
+ raise Exception(
70
+ f"{key} not found in either pre-trained models (e.g. BERT)"
71
+ " or resumed checkpoints (e.g. epoch_40/model.pt)"
72
+ )
73
+ return super().load_state_dict(state_dict, strict)
74
+
75
+
76
+ class DurationAdapterMixin:
77
+ def __init__(
78
+ self,
79
+ latent_token_rate: int,
80
+ offset: float = 1.0,
81
+ frame_resolution: float | None = None
82
+ ):
83
+ self.latent_token_rate = latent_token_rate
84
+ self.offset = offset
85
+ self.frame_resolution = frame_resolution
86
+
87
+ def get_global_duration_loss(
88
+ self,
89
+ pred: torch.Tensor,
90
+ latent_mask: torch.Tensor,
91
+ reduce: bool = True,
92
+ ):
93
+ target = torch.log(
94
+ latent_mask.sum(1) / self.latent_token_rate + self.offset
95
+ )
96
+ loss = F.mse_loss(target, pred, reduction="mean" if reduce else "none")
97
+ return loss
98
+
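# Worked example for the global-duration target (numbers are hypothetical):
# with latent_token_rate = 25 and 250 valid latent frames, the audio lasts
# 250 / 25 = 10 s, so the regression target is log(10 + offset) = log(11)
# ≈ 2.398 with the default offset of 1.0; the MSE is computed against `pred`
# in this log-with-offset space.
import math
assert abs(math.log(250 / 25 + 1.0) - 2.398) < 1e-3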
99
+ def get_local_duration_loss(
100
+ self, ground_truth: torch.Tensor, pred: torch.Tensor,
101
+ mask: torch.Tensor, is_time_aligned: Sequence[bool], reduce: bool
102
+ ):
103
+ n_frames = torch.round(ground_truth / self.frame_resolution)
104
+ target = torch.log(n_frames + self.offset)
105
+ loss = loss_with_mask(
106
+ (target - pred)**2,
107
+ mask,
108
+ reduce=False,
109
+ )
110
+ loss *= is_time_aligned
111
+ if reduce:
112
+ if is_time_aligned.sum().item() == 0:
113
+ loss *= 0.0
114
+ loss = loss.mean()
115
+ else:
116
+ loss = loss.sum() / is_time_aligned.sum()
117
+
118
+ return loss
119
+
120
+ def prepare_local_duration(self, pred: torch.Tensor, mask: torch.Tensor):
121
+ pred = torch.exp(pred) * mask
122
+ pred = torch.ceil(pred) - self.offset
123
+ pred *= self.frame_resolution
124
+ return pred
125
+
126
+ def prepare_global_duration(
127
+ self,
128
+ global_pred: torch.Tensor,
129
+ local_pred: torch.Tensor,
130
+ is_time_aligned: Sequence[bool],
131
+ use_local: bool = True,
132
+ ):
133
+ """
134
+ global_pred: predicted global duration, in log scale with the offset applied
136
+ local_pred: predicted per-token duration (in seconds)
136
+ """
137
+ global_pred = torch.exp(global_pred) - self.offset
138
+ result = global_pred
139
+ # avoid error accumulation for each frame
140
+ if use_local:
141
+ pred_from_local = torch.round(local_pred * self.latent_token_rate)
142
+ pred_from_local = pred_from_local.sum(1) / self.latent_token_rate
143
+ result[is_time_aligned] = pred_from_local[is_time_aligned]
144
+
145
+ return result
146
+
147
+ def expand_by_duration(
148
+ self,
149
+ x: torch.Tensor,
150
+ content_mask: torch.Tensor,
151
+ local_duration: torch.Tensor,
152
+ global_duration: torch.Tensor | None = None,
153
+ ):
154
+ n_latents = torch.round(local_duration * self.latent_token_rate)
155
+ if global_duration is not None:
156
+ latent_length = torch.round(
157
+ global_duration * self.latent_token_rate
158
+ )
159
+ else:
160
+ latent_length = n_latents.sum(1)
161
+ latent_mask = create_mask_from_length(latent_length).to(
162
+ content_mask.device
163
+ )
164
+ attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
165
+ align_path = create_alignment_path(n_latents, attn_mask)
166
+ expanded_x = torch.matmul(align_path.transpose(1, 2).to(x.dtype), x)
167
+ return expanded_x, latent_mask
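# Shape sketch for expand_by_duration (names and sizes are hypothetical, and
# the repetition behaviour is inferred from the alignment-path call above):
#   x:              (B, N_tokens, D)  content embeddings
#   local_duration: (B, N_tokens)     predicted seconds per token
#   expanded_x:     (B, N_latents, D) roughly a length-regulator output, each
#                                     token covering about
#                                     round(local_duration * latent_token_rate)
#                                     latent frames via the hard alignment path
#   latent_mask:    (B, N_latents)    validity mask from the summed (or given
#                                     global) duration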
models/content_adapter.py ADDED
@@ -0,0 +1,430 @@
1
+ import math
2
+ from typing import Any
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from utils.torch_utilities import concat_non_padding, restore_from_concat, create_mask_from_length
7
+ from models.content_encoder.content_encoder import ContentEncoder
8
+
9
+
10
+ ######################
11
+ # fastspeech modules
12
+ ######################
13
+ class LayerNorm(nn.LayerNorm):
14
+ """Layer normalization module.
15
+ :param int nout: output dim size
16
+ :param int dim: dimension to be normalized
17
+ """
18
+ def __init__(self, nout, dim=-1):
19
+ """Construct an LayerNorm object."""
20
+ super(LayerNorm, self).__init__(nout, eps=1e-12)
21
+ self.dim = dim
22
+
23
+ def forward(self, x):
24
+ """Apply layer normalization.
25
+ :param torch.Tensor x: input tensor
26
+ :return: layer normalized tensor
27
+ :rtype torch.Tensor
28
+ """
29
+ if self.dim == -1:
30
+ return super(LayerNorm, self).forward(x)
31
+ return super(LayerNorm,
32
+ self).forward(x.transpose(1, -1)).transpose(1, -1)
33
+
34
+
35
+ class DurationPredictor(nn.Module):
36
+ def __init__(
37
+ self,
38
+ in_channels: int,
39
+ filter_channels: int,
40
+ n_layers: int = 2,
41
+ kernel_size: int = 3,
42
+ p_dropout: float = 0.1,
43
+ padding: str = "SAME"
44
+ ):
45
+ super(DurationPredictor, self).__init__()
46
+ self.conv = nn.ModuleList()
47
+ self.kernel_size = kernel_size
48
+ self.padding = padding
49
+ for idx in range(n_layers):
50
+ in_chans = in_channels if idx == 0 else filter_channels
51
+ self.conv += [
52
+ nn.Sequential(
53
+ nn.ConstantPad1d(((kernel_size - 1) // 2,
54
+ (kernel_size - 1) //
55
+ 2) if padding == 'SAME' else
56
+ (kernel_size - 1, 0), 0),
57
+ nn.Conv1d(
58
+ in_chans,
59
+ filter_channels,
60
+ kernel_size,
61
+ stride=1,
62
+ padding=0
63
+ ), nn.ReLU(), LayerNorm(filter_channels, dim=1),
64
+ nn.Dropout(p_dropout)
65
+ )
66
+ ]
67
+ self.linear = nn.Linear(filter_channels, 1)
68
+
69
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
70
+ # x: [B, T, E]
71
+ x = x.transpose(1, -1)
72
+ x_mask = x_mask.unsqueeze(1).to(x.device)
73
+ for f in self.conv:
74
+ x = f(x)
75
+ x = x * x_mask.float()
76
+
77
+ x = self.linear(x.transpose(1, -1)
78
+ ) * x_mask.transpose(1, -1).float() # [B, T, 1]
79
+ return x
80
+
81
+
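# Minimal smoke test for DurationPredictor (sizes are hypothetical, not part
# of the committed file): it stacks Conv1d -> ReLU -> LayerNorm -> Dropout
# blocks plus a linear head, mapping (B, T, E) features and a (B, T) mask to
# a (B, T, 1) per-token duration prediction.
import torch
dp = DurationPredictor(in_channels=256, filter_channels=256)
feats = torch.randn(2, 17, 256)
mask = torch.ones(2, 17)
assert dp(feats, mask).shape == (2, 17, 1)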
82
+ ######################
83
+ # adapter modules
84
+ ######################
85
+
86
+
87
+ class ContentAdapterBase(nn.Module):
88
+ def __init__(self, d_out):
89
+ super().__init__()
90
+ self.d_out = d_out
91
+
92
+
93
+ class SinusoidalPositionalEmbedding(nn.Module):
94
+ def __init__(self, d_model, dropout, max_len=1000):
95
+ super().__init__()
96
+ self.dropout = nn.Dropout(dropout)
97
+ pe = torch.zeros(max_len, d_model)
98
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
99
+ div_term = torch.exp(
100
+ torch.arange(0, d_model, 2).float() *
101
+ (-math.log(10000.0) / d_model)
102
+ )
103
+ pe[:, 0::2] = torch.sin(position * div_term)
104
+ pe[:, 1::2] = torch.cos(position * div_term)
105
+ pe = pe.unsqueeze(0).transpose(0, 1)
106
+ self.register_buffer('pe', pe)
107
+
108
+ def forward(self, x):
109
+ x = x + self.pe[:x.size(1), :]
110
+ return self.dropout(x)
111
+
112
+
113
+ class ContentAdapter(ContentAdapterBase):
114
+ def __init__(
115
+ self,
116
+ d_model: int,
117
+ d_out: int,
118
+ num_layers: int,
119
+ num_heads: int,
120
+ duration_predictor: DurationPredictor,
121
+ dropout: float = 0.1,
122
+ norm_first: bool = False,
123
+ activation: str = "gelu",
124
+ duration_grad_scale: float = 0.0,
125
+ ):
126
+ super().__init__(d_out)
127
+ self.duration_grad_scale = duration_grad_scale
128
+ self.cls_embed = nn.Parameter(torch.randn(d_model))
129
+ if hasattr(torch, "npu") and torch.npu.is_available():
130
+ enable_nested_tensor = False
131
+ else:
132
+ enable_nested_tensor = True
133
+ encoder_layer = nn.TransformerEncoderLayer(
134
+ d_model=d_model,
135
+ nhead=num_heads,
136
+ dim_feedforward=4 * d_model,
137
+ dropout=dropout,
138
+ activation=activation,
139
+ norm_first=norm_first,
140
+ batch_first=True
141
+ )
142
+ self.encoder_layers = nn.TransformerEncoder(
143
+ encoder_layer=encoder_layer,
144
+ num_layers=num_layers,
145
+ enable_nested_tensor=enable_nested_tensor
146
+ )
147
+ self.duration_predictor = duration_predictor
148
+ self.content_proj = nn.Conv1d(d_model, d_out, 1)
149
+
150
+ def forward(self, x, x_mask):
151
+ batch_size = x.size(0)
152
+ cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1)
153
+ cls_embed = cls_embed.to(x.device).unsqueeze(1)
154
+ x = torch.cat([cls_embed, x], dim=1)
155
+
156
+ cls_mask = torch.ones(batch_size, 1).to(x_mask.device)
157
+ x_mask = torch.cat([cls_mask, x_mask], dim=1)
158
+ x = self.encoder_layers(x, src_key_padding_mask=~x_mask.bool())
159
+ x_grad_rescaled = x * self.duration_grad_scale + x.detach(
160
+ ) * (1 - self.duration_grad_scale)
161
+ duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1)
162
+ content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
163
+ return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]
164
+
165
+
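# Note on ContentAdapter's outputs (read off the code above): a learned CLS
# embedding is prepended at position 0, so duration[:, 0] serves as the global
# (utterance-level) duration prediction and duration[:, 1:] as the per-token
# ones, while `content` and `x_mask` are returned with the CLS slot dropped.
# duration_grad_scale only lets that fraction of the duration-loss gradient
# flow back into the shared Transformer features.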
166
+ class PrefixAdapter(ContentAdapterBase):
167
+ def __init__(
168
+ self,
169
+ content_dim: int,
170
+ d_model: int,
171
+ d_out: int,
172
+ prefix_dim: int,
173
+ num_layers: int,
174
+ num_heads: int,
175
+ duration_predictor: DurationPredictor,
176
+ dropout: float = 0.1,
177
+ norm_first: bool = False,
178
+ use_last_norm: bool = True,
179
+ activation: str = "gelu",
180
+ duration_grad_scale: float = 0.1,
181
+ ):
182
+ super().__init__(d_out)
183
+ self.duration_grad_scale = duration_grad_scale
184
+ self.prefix_mlp = nn.Sequential(
185
+ nn.Linear(prefix_dim, d_model), nn.ReLU(), nn.Dropout(dropout),
186
+ nn.Linear(d_model, d_model)
187
+ )
188
+ self.content_mlp = nn.Sequential(
189
+ nn.Linear(content_dim, d_model), nn.ReLU(), nn.Dropout(dropout),
190
+ nn.Linear(d_model, d_model)
191
+ )
192
+ layer = nn.TransformerEncoderLayer(
193
+ d_model=d_model,
194
+ nhead=num_heads,
195
+ dim_feedforward=4 * d_model,
196
+ dropout=dropout,
197
+ activation=activation,
198
+ batch_first=True,
199
+ norm_first=norm_first
200
+ )
201
+ if hasattr(torch, "npu") and torch.npu.is_available():
202
+ enable_nested_tensor = False
203
+ else:
204
+ enable_nested_tensor = True
205
+ self.cls_embed = nn.Parameter(torch.randn(d_model))
206
+ # self.pos_embed = SinusoidalPositionalEmbedding(d_model, dropout)
207
+ self.layers = nn.TransformerEncoder(
208
+ encoder_layer=layer,
209
+ num_layers=num_layers,
210
+ enable_nested_tensor=enable_nested_tensor
211
+ )
212
+ self.use_last_norm = use_last_norm
213
+ if self.use_last_norm:
214
+ self.last_norm = nn.LayerNorm(d_model)
215
+ self.duration_predictor = duration_predictor
216
+ self.content_proj = nn.Conv1d(d_model, d_out, 1)
217
+ nn.init.normal_(self.cls_embed, 0., 0.02)
218
+ nn.init.xavier_uniform_(self.content_proj.weight)
219
+ nn.init.constant_(self.content_proj.bias, 0.)
220
+
221
+ def forward(self, content, content_mask, instruction, instruction_mask):
222
+ batch_size = content.size(0)
223
+ cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1)
224
+ cls_embed = cls_embed.to(content.device).unsqueeze(1)
225
+ content = self.content_mlp(content)
226
+ x = torch.cat([cls_embed, content], dim=1)
227
+ cls_mask = torch.ones(batch_size, 1,
228
+ dtype=bool).to(content_mask.device)
229
+ x_mask = torch.cat([cls_mask, content_mask], dim=1)
230
+
231
+ prefix = self.prefix_mlp(instruction)
232
+ seq, seq_mask, perm = concat_non_padding(
233
+ prefix, instruction_mask, x, x_mask
234
+ )
235
+ # seq = self.pos_embed(seq)
236
+ x = self.layers(seq, src_key_padding_mask=~seq_mask.bool())
237
+ if self.use_last_norm:
238
+ x = self.last_norm(x)
239
+ _, x = restore_from_concat(x, instruction_mask, x_mask, perm)
240
+
241
+ x_grad_rescaled = x * self.duration_grad_scale + x.detach(
242
+ ) * (1 - self.duration_grad_scale)
243
+ duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1)
244
+ content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
245
+ return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]
246
+
247
+
248
+ class CrossAttentionAdapter(ContentAdapterBase):
249
+ def __init__(
250
+ self,
251
+ d_out: int,
252
+ content_dim: int,
253
+ prefix_dim: int,
254
+ num_heads: int,
255
+ duration_predictor: DurationPredictor,
256
+ dropout: float = 0.1,
257
+ duration_grad_scale: float = 0.1,
258
+ ):
259
+ super().__init__(d_out)
260
+ self.attn = nn.MultiheadAttention(
261
+ embed_dim=content_dim,
262
+ num_heads=num_heads,
263
+ dropout=dropout,
264
+ kdim=prefix_dim,
265
+ vdim=prefix_dim,
266
+ batch_first=True,
267
+ )
268
+ self.duration_grad_scale = duration_grad_scale
269
+ self.duration_predictor = duration_predictor
270
+ self.global_duration_mlp = nn.Sequential(
271
+ nn.Linear(content_dim, content_dim), nn.ReLU(),
272
+ nn.Dropout(dropout), nn.Linear(content_dim, 1)
273
+ )
274
+ self.norm = nn.LayerNorm(content_dim)
275
+ self.content_proj = nn.Conv1d(content_dim, d_out, 1)
276
+
277
+ def forward(self, content, content_mask, prefix, prefix_mask):
278
+ attn_output, attn_output_weights = self.attn(
279
+ query=content,
280
+ key=prefix,
281
+ value=prefix,
282
+ key_padding_mask=~prefix_mask.bool()
283
+ )
284
+ attn_output = attn_output * content_mask.unsqueeze(-1).float()
285
+ x = self.norm(attn_output + content)
286
+ x_grad_rescaled = x * self.duration_grad_scale + x.detach(
287
+ ) * (1 - self.duration_grad_scale)
288
+ x_aggregated = (x_grad_rescaled * content_mask.unsqueeze(-1).float()
289
+ ).sum(dim=1) / content_mask.sum(dim=1,
290
+ keepdim=True).float()
291
+ global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
292
+ local_duration = self.duration_predictor(
293
+ x_grad_rescaled, content_mask
294
+ ).squeeze(-1)
295
+ content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
296
+ return content, content_mask, global_duration, local_duration
297
+
298
+
299
+ class ExperimentalCrossAttentionAdapter(ContentAdapterBase):
300
+ def __init__(
301
+ self,
302
+ d_out: int,
303
+ content_dim: int,
304
+ prefix_dim: int,
305
+ num_heads: int,
306
+ duration_predictor: DurationPredictor,
307
+ dropout: float = 0.1,
308
+ duration_grad_scale: float = 0.1,
309
+ ):
310
+ super().__init__(d_out)
311
+ self.content_mlp = nn.Sequential(
312
+ nn.Linear(content_dim, content_dim),
313
+ nn.ReLU(),
314
+ nn.Dropout(dropout),
315
+ nn.Linear(content_dim, content_dim),
316
+ )
317
+ self.content_norm = nn.LayerNorm(content_dim)
318
+ self.prefix_mlp = nn.Sequential(
319
+ nn.Linear(prefix_dim, prefix_dim),
320
+ nn.ReLU(),
321
+ nn.Dropout(dropout),
322
+ nn.Linear(prefix_dim, prefix_dim),
323
+ )
324
+ self.prefix_norm = nn.LayerNorm(content_dim)
325
+ self.attn = nn.MultiheadAttention(
326
+ embed_dim=content_dim,
327
+ num_heads=num_heads,
328
+ dropout=dropout,
329
+ kdim=prefix_dim,
330
+ vdim=prefix_dim,
331
+ batch_first=True,
332
+ )
333
+ self.duration_grad_scale = duration_grad_scale
334
+ self.duration_predictor = duration_predictor
335
+ self.global_duration_mlp = nn.Sequential(
336
+ nn.Linear(content_dim, content_dim), nn.ReLU(),
337
+ nn.Dropout(dropout), nn.Linear(content_dim, 1)
338
+ )
339
+ self.content_proj = nn.Sequential(
340
+ nn.Linear(content_dim, d_out),
341
+ nn.ReLU(),
342
+ nn.Dropout(dropout),
343
+ nn.Linear(d_out, d_out),
344
+ )
345
+ self.norm1 = nn.LayerNorm(content_dim)
346
+ self.norm2 = nn.LayerNorm(d_out)
347
+ self.init_weights()
348
+
349
+ def init_weights(self):
350
+ def _init_weights(module):
351
+ if isinstance(module, nn.Linear):
352
+ nn.init.xavier_uniform_(module.weight)
353
+ if module.bias is not None:
354
+ nn.init.constant_(module.bias, 0.)
355
+
356
+ self.apply(_init_weights)
357
+
358
+ def forward(self, content, content_mask, prefix, prefix_mask):
359
+ content = self.content_mlp(content)
360
+ content = self.content_norm(content)
361
+ prefix = self.prefix_mlp(prefix)
362
+ prefix = self.prefix_norm(prefix)
363
+ attn_output, attn_weights = self.attn(
364
+ query=content,
365
+ key=prefix,
366
+ value=prefix,
367
+ key_padding_mask=~prefix_mask.bool(),
368
+ )
369
+ attn_output = attn_output * content_mask.unsqueeze(-1).float()
370
+ x = attn_output + content
371
+ x = self.norm1(x)
372
+ x_grad_rescaled = x * self.duration_grad_scale + x.detach(
373
+ ) * (1 - self.duration_grad_scale)
374
+ x_aggregated = (x_grad_rescaled * content_mask.unsqueeze(-1).float()
375
+ ).sum(dim=1) / content_mask.sum(dim=1,
376
+ keepdim=True).float()
377
+ global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
378
+ local_duration = self.duration_predictor(
379
+ x_grad_rescaled, content_mask
380
+ ).squeeze(-1)
381
+ content = self.content_proj(x)
382
+ content = self.norm2(content)
383
+ return content, content_mask, global_duration, local_duration
384
+
385
+
386
+ class ContentEncoderAdapterMixin:
387
+ def __init__(
388
+ self,
389
+ content_encoder: ContentEncoder,
390
+ content_adapter: ContentAdapterBase | None = None
391
+ ):
392
+ self.content_encoder = content_encoder
393
+ self.content_adapter = content_adapter
394
+
395
+ def encode_content(
396
+ self,
397
+ content: list[Any],
398
+ task: list[str],
399
+ device: str | torch.device,
400
+ instruction: torch.Tensor | None = None,
401
+ instruction_lengths: torch.Tensor | None = None
402
+ ):
403
+ content_output: dict[
404
+ str, torch.Tensor] = self.content_encoder.encode_content(
405
+ content, task, device=device
406
+ )
407
+ content, content_mask = content_output["content"], content_output[
408
+ "content_mask"]
409
+
410
+ if instruction is not None:
411
+ instruction_mask = create_mask_from_length(instruction_lengths)
412
+ (
413
+ content,
414
+ content_mask,
415
+ global_duration_pred,
416
+ local_duration_pred,
417
+ ) = self.content_adapter(
418
+ content, content_mask, instruction, instruction_mask
419
+ )
420
+
421
+ return_dict = {
422
+ "content": content,
423
+ "content_mask": content_mask,
424
+ "length_aligned_content": content_output["length_aligned_content"],
425
+ }
426
+ if instruction is not None:
427
+ return_dict["global_duration_pred"] = global_duration_pred
428
+ return_dict["local_duration_pred"] = local_duration_pred
429
+
430
+ return return_dict
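# Inferred from this mixin and the ContentEncoder it wraps (shapes are not
# documented upstream): encode_content() returns a dict with "content"
# (B, T_content, D), "content_mask" (B, T_content) and
# "length_aligned_content" (frame-aligned features, a zero placeholder for
# tasks without them), plus "global_duration_pred" / "local_duration_pred"
# whenever an instruction is provided and the content adapter runs.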
models/content_encoder/content_encoder.py ADDED
@@ -0,0 +1,244 @@
1
+ from typing import Any
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class ContentEncoder(nn.Module):
7
+ def __init__(
8
+ self,
9
+ embed_dim: int,
10
+ text_encoder: nn.Module = None,
11
+ video_encoder: nn.Module = None,
12
+ midi_encoder: nn.Module = None,
13
+ phoneme_encoder: nn.Module = None,
14
+ pitch_encoder: nn.Module = None,
15
+ audio_encoder: nn.Module = None
16
+ ):
17
+ super().__init__()
18
+ self.embed_dim = embed_dim
19
+ self.text_encoder = text_encoder
20
+ self.midi_encoder = midi_encoder
21
+ self.phoneme_encoder = phoneme_encoder
22
+ self.pitch_encoder = pitch_encoder
23
+ self.audio_encoder = audio_encoder
24
+ self.video_encoder = video_encoder
25
+
26
+ def encode_content(
27
+ self, batch_content: list[Any], batch_task: list[str],
28
+ device: str | torch.device
29
+ ):
30
+ batch_content_output = []
31
+ batch_content_mask = []
32
+ batch_la_content_output = []
33
+
34
+ zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)
35
+ for content, task in zip(batch_content, batch_task):
36
+ if task == "audio_super_resolution" or task == "speech_enhancement":
37
+ content_dict = {
38
+ "waveform": torch.as_tensor(content).float(),
39
+ "waveform_lengths": torch.as_tensor(content.shape[0]),
40
+ }
41
+ for key in list(content_dict.keys()):
42
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
43
+ device
44
+ )
45
+ content_output_dict = self.audio_encoder(**content_dict)
46
+ la_content_output_dict = {
47
+ "output": zero_la_content,
48
+ }
49
+ elif task == "text_to_audio" or task == "text_to_music":
50
+ content_output_dict = self.text_encoder([content])
51
+ la_content_output_dict = {
52
+ "output": zero_la_content,
53
+ }
54
+ elif task == "video_to_audio":
55
+ content_dict = {
56
+ "frames": torch.as_tensor(content).float(),
57
+ "frame_nums": torch.as_tensor(content.shape[0]),
58
+ }
59
+ for key in list(content_dict.keys()):
60
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
61
+ device
62
+ )
63
+ content_output_dict = self.video_encoder(**content_dict)
64
+ la_content_output_dict = {
65
+ "output": zero_la_content,
66
+ }
67
+ elif task == "singing_voice_synthesis":
68
+ content_dict = {
69
+ "phoneme":
70
+ torch.as_tensor(content["phoneme"]).long(),
71
+ "midi":
72
+ torch.as_tensor(content["midi"]).long(),
73
+ "midi_duration":
74
+ torch.as_tensor(content["midi_duration"]).float(),
75
+ "is_slur":
76
+ torch.as_tensor(content["is_slur"]).long()
77
+ }
78
+ if "spk" in content:
79
+ if self.midi_encoder.spk_config.encoding_format == "id":
80
+ content_dict["spk"] = torch.as_tensor(content["spk"]
81
+ ).long()
82
+ elif self.midi_encoder.spk_config.encoding_format == "embedding":
83
+ content_dict["spk"] = torch.as_tensor(content["spk"]
84
+ ).float()
85
+ for key in list(content_dict.keys()):
86
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
87
+ device
88
+ )
89
+ content_dict["lengths"] = torch.as_tensor([
90
+ len(content["phoneme"])
91
+ ])
92
+ content_output_dict = self.midi_encoder(**content_dict)
93
+ la_content_output_dict = {"output": zero_la_content}
94
+ elif task == "text_to_speech":
95
+ content_dict = {
96
+ "phoneme": torch.as_tensor(content["phoneme"]).long(),
97
+ }
98
+ if "spk" in content:
99
+ if self.phoneme_encoder.spk_config.encoding_format == "id":
100
+ content_dict["spk"] = torch.as_tensor(content["spk"]
101
+ ).long()
102
+ elif self.phoneme_encoder.spk_config.encoding_format == "embedding":
103
+ content_dict["spk"] = torch.as_tensor(content["spk"]
104
+ ).float()
105
+ for key in list(content_dict.keys()):
106
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
107
+ device
108
+ )
109
+ content_dict["lengths"] = torch.as_tensor([
110
+ len(content["phoneme"])
111
+ ])
112
+ content_output_dict = self.phoneme_encoder(**content_dict)
113
+ la_content_output_dict = {"output": zero_la_content}
114
+ elif task == "singing_acoustic_modeling":
115
+ content_dict = {
116
+ "phoneme": torch.as_tensor(content["phoneme"]).long(),
117
+ }
118
+ for key in list(content_dict.keys()):
119
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
120
+ device
121
+ )
122
+ content_dict["lengths"] = torch.as_tensor([
123
+ len(content["phoneme"])
124
+ ])
125
+ content_output_dict = self.pitch_encoder(**content_dict)
126
+
127
+ content_dict = {
128
+ "f0": torch.as_tensor(content["f0"]),
129
+ "uv": torch.as_tensor(content["uv"]),
130
+ }
131
+ for key in list(content_dict.keys()):
132
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
133
+ device
134
+ )
135
+ la_content_output_dict = self.pitch_encoder.encode_pitch(
136
+ **content_dict
137
+ )
138
+ else:
139
+ raise ValueError(f"Unsupported task: {task}")
140
+
141
+ batch_content_output.append(content_output_dict["output"][0])
142
+ batch_content_mask.append(content_output_dict["mask"][0])
143
+ batch_la_content_output.append(la_content_output_dict["output"][0])
144
+
145
+ batch_content_output = nn.utils.rnn.pad_sequence(
146
+ batch_content_output, batch_first=True, padding_value=0
147
+ )
148
+ batch_content_mask = nn.utils.rnn.pad_sequence(
149
+ batch_content_mask, batch_first=True, padding_value=False
150
+ )
151
+ batch_la_content_output = nn.utils.rnn.pad_sequence(
152
+ batch_la_content_output, batch_first=True, padding_value=0
153
+ )
154
+ return {
155
+ "content": batch_content_output,
156
+ "content_mask": batch_content_mask,
157
+ "length_aligned_content": batch_la_content_output,
158
+ }
159
+
160
+
161
+ class BatchedContentEncoder(ContentEncoder):
162
+ def encode_content(
163
+ self, batch_content: list | dict, batch_task: list[str],
164
+ device: str | torch.device
165
+ ):
166
+ task = batch_task[0]
167
+ zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)
168
+ if task == "audio_super_resolution" or task == "speech_enhancement":
169
+ content_dict = {
170
+ "waveform":
171
+ batch_content["content"].unsqueeze(1).float().to(device),
172
+ "waveform_lengths":
173
+ batch_content["content_lengths"].long().to(device),
174
+ }
175
+ content_output = self.audio_encoder(**content_dict)
176
+ la_content_output = zero_la_content
177
+ elif task == "text_to_audio":
178
+ content_output = self.text_encoder(batch_content)
179
+ la_content_output = zero_la_content
180
+ elif task == "video_to_audio":
181
+ content_dict = {
182
+ "frames":
183
+ batch_content["content"].float().to(device),
184
+ "frame_nums":
185
+ batch_content["content_lengths"].long().to(device),
186
+ }
187
+ content_output = self.video_encoder(**content_dict)
188
+ la_content_output = zero_la_content
189
+ elif task == "singing_voice_synthesis":
190
+ content_dict = {
191
+ "phoneme":
192
+ batch_content["phoneme"].long().to(device),
193
+ "midi":
194
+ batch_content["midi"].long().to(device),
195
+ "midi_duration":
196
+ batch_content["midi_duration"].float().to(device),
197
+ "is_slur":
198
+ batch_content["is_slur"].long().to(device),
199
+ "lengths":
200
+ batch_content["phoneme_lengths"].long().cpu(),
201
+ }
202
+ if "spk" in batch_content:
203
+ if self.midi_encoder.spk_config.encoding_format == "id":
204
+ content_dict["spk"] = batch_content["spk"].long(
205
+ ).to(device)
206
+ elif self.midi_encoder.spk_config.encoding_format == "embedding":
207
+ content_dict["spk"] = batch_content["spk"].float(
208
+ ).to(device)
209
+ content_output = self.midi_encoder(**content_dict)
210
+ la_content_output = zero_la_content
211
+ elif task == "text_to_speech":
212
+ content_dict = {
213
+ "phoneme": batch_content["phoneme"].long().to(device),
214
+ "lengths": batch_content["phoneme_lengths"].long().cpu(),
215
+ }
216
+ if "spk" in batch_content:
217
+ if self.phoneme_encoder.spk_config.encoding_format == "id":
218
+ content_dict["spk"] = batch_content["spk"].long(
219
+ ).to(device)
220
+ elif self.phoneme_encoder.spk_config.encoding_format == "embedding":
221
+ content_dict["spk"] = batch_content["spk"].float(
222
+ ).to(device)
223
+ content_output = self.phoneme_encoder(**content_dict)
224
+ la_content_output = zero_la_content
225
+ elif task == "singing_acoustic_modeling":
226
+ content_dict = {
227
+ "phoneme": batch_content["phoneme"].long().to(device),
228
+ "lengths": batch_content["phoneme_lengths"].long().to(device),
229
+ }
230
+ content_output = self.pitch_encoder(**content_dict)
231
+
232
+ content_dict = {
233
+ "f0": batch_content["f0"].float().to(device),
234
+ "uv": batch_content["uv"].float().to(device),
235
+ }
236
+ la_content_output = self.pitch_encoder.encode_pitch(**content_dict)
237
+ else:
238
+ raise ValueError(f"Unsupported task: {task}")
239
+
240
+ return {
241
+ "content": content_output["output"],
242
+ "content_mask": content_output["mask"],
243
+ "length_aligned_content": la_content_output,
244
+ }
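# The two encoders differ only in batching strategy: ContentEncoder loops over
# samples (so one batch may mix tasks) and pads the per-sample outputs, while
# BatchedContentEncoder assumes the whole batch shares batch_task[0] and feeds
# pre-collated tensors straight to the matching sub-encoder. Both return the
# same {"content", "content_mask", "length_aligned_content"} dictionary.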
models/content_encoder/midi_encoder.py ADDED
@@ -0,0 +1,1046 @@
1
+ from typing import Sequence
2
+ from dataclasses import dataclass
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.nn import Parameter
8
+
9
+ from utils.torch_utilities import create_mask_from_length
10
+ from utils.diffsinger_utilities import denorm_f0, f0_to_coarse
11
+
12
+
13
+ def make_positions(tensor, padding_idx):
14
+ """Replace non-padding symbols with their position numbers.
15
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
16
+ """
17
+ # The series of casts and type-conversions here are carefully
18
+ # balanced to both work with ONNX export and XLA. In particular XLA
19
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
20
+ # how to handle the dtype kwarg in cumsum.
21
+ mask = tensor.ne(padding_idx).int()
22
+ return (torch.cumsum(mask, dim=1).type_as(mask) *
23
+ mask).long() + padding_idx
24
+
25
+
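# make_positions() numbers non-padding tokens starting at padding_idx + 1 and
# keeps padding at padding_idx. A small worked example (values are arbitrary):
import torch
tokens = torch.tensor([[5, 7, 0]])   # 0 acts as the padding index here
assert make_positions(tokens, padding_idx=0).tolist() == [[1, 2, 0]]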
26
+ def softmax(x, dim):
27
+ return F.softmax(x, dim=dim, dtype=torch.float32)
28
+
29
+
30
+ def LayerNorm(
31
+ normalized_shape, eps=1e-5, elementwise_affine=True, export=False
32
+ ):
33
+ if not export and torch.cuda.is_available():
34
+ try:
35
+ from apex.normalization import FusedLayerNorm
36
+ return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
37
+ except ImportError:
38
+ pass
39
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
40
+
41
+
42
+ def Linear(in_features, out_features, bias=True):
43
+ m = nn.Linear(in_features, out_features, bias)
44
+ nn.init.xavier_uniform_(m.weight)
45
+ if bias:
46
+ nn.init.constant_(m.bias, 0.)
47
+ return m
48
+
49
+
50
+ def Embedding(num_embeddings, embedding_dim, padding_idx=None):
51
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
52
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
53
+ if padding_idx is not None:
54
+ nn.init.constant_(m.weight[padding_idx], 0)
55
+ return m
56
+
57
+
58
+ class BatchNorm1dTBC(nn.Module):
59
+ def __init__(self, c):
60
+ super(BatchNorm1dTBC, self).__init__()
61
+ self.bn = nn.BatchNorm1d(c)
62
+
63
+ def forward(self, x):
64
+ """
65
+
66
+ :param x: [T, B, C]
67
+ :return: [T, B, C]
68
+ """
69
+ x = x.permute(1, 2, 0) # [B, C, T]
70
+ x = self.bn(x) # [B, C, T]
71
+ x = x.permute(2, 0, 1) # [T, B, C]
72
+ return x
73
+
74
+
75
+ class PositionalEncoding(nn.Module):
76
+ """Positional encoding.
77
+ Args:
78
+ d_model (int): Embedding dimension.
79
+ dropout_rate (float): Dropout rate.
80
+ max_len (int): Maximum input length.
81
+ reverse (bool): Whether to reverse the input position.
82
+ """
83
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
84
+ """Construct an PositionalEncoding object."""
85
+ super(PositionalEncoding, self).__init__()
86
+ self.d_model = d_model
87
+ self.reverse = reverse
88
+ self.xscale = math.sqrt(self.d_model)
89
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
90
+ self.pe = None
91
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
92
+
93
+ def extend_pe(self, x):
94
+ """Reset the positional encodings."""
95
+ if self.pe is not None:
96
+ if self.pe.size(1) >= x.size(1):
97
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
98
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
99
+ return
100
+ pe = torch.zeros(x.size(1), self.d_model)
101
+ if self.reverse:
102
+ position = torch.arange(
103
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
104
+ ).unsqueeze(1)
105
+ else:
106
+ position = torch.arange(0, x.size(1),
107
+ dtype=torch.float32).unsqueeze(1)
108
+ div_term = torch.exp(
109
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
110
+ -(math.log(10000.0) / self.d_model)
111
+ )
112
+ pe[:, 0::2] = torch.sin(position * div_term)
113
+ pe[:, 1::2] = torch.cos(position * div_term)
114
+ pe = pe.unsqueeze(0)
115
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
116
+
117
+ def forward(self, x: torch.Tensor):
118
+ """Add positional encoding.
119
+ Args:
120
+ x (torch.Tensor): Input tensor (batch, time, `*`).
121
+ Returns:
122
+ torch.Tensor: Encoded tensor (batch, time, `*`).
123
+ """
124
+ self.extend_pe(x)
125
+ x = x * self.xscale + self.pe[:, :x.size(1)]
126
+ return self.dropout(x)
127
+
128
+
129
+ class SinusoidalPositionalEmbedding(nn.Module):
130
+ """This module produces sinusoidal positional embeddings of any length.
131
+
132
+ Padding symbols are ignored.
133
+ """
134
+ def __init__(self, d_model, padding_idx, init_size=2048):
135
+ super().__init__()
136
+ self.d_model = d_model
137
+ self.padding_idx = padding_idx
138
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
139
+ init_size,
140
+ d_model,
141
+ padding_idx,
142
+ )
143
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
144
+
145
+ @staticmethod
146
+ def get_embedding(num_embeddings, d_model, padding_idx=None):
147
+ """Build sinusoidal embeddings.
148
+
149
+ This matches the implementation in tensor2tensor, but differs slightly
150
+ from the description in Section 3.5 of "Attention Is All You Need".
151
+ """
152
+ half_dim = d_model // 2
153
+ emb = math.log(10000) / (half_dim - 1)
154
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
155
+ emb = torch.arange(num_embeddings,
156
+ dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
157
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)],
158
+ dim=1).view(num_embeddings, -1)
159
+ if d_model % 2 == 1:
160
+ # zero pad
161
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
162
+ if padding_idx is not None:
163
+ emb[padding_idx, :] = 0
164
+ return emb
165
+
166
+ def forward(
167
+ self,
168
+ x,
169
+ lengths,
170
+ incremental_state=None,
171
+ timestep=None,
172
+ positions=None,
173
+ **kwargs
174
+ ):
175
+ """Input is expected to be of size [bsz x seqlen]."""
176
+ bsz, seq_len = x.shape[:2]
177
+ max_pos = self.padding_idx + 1 + seq_len
178
+ if self.weights is None or max_pos > self.weights.size(0):
179
+ # recompute/expand embeddings if needed
180
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
181
+ max_pos,
182
+ self.d_model,
183
+ self.padding_idx,
184
+ )
185
+ self.weights = self.weights.to(self._float_tensor)
186
+
187
+ if incremental_state is not None:
188
+ # positions is the same for every token when decoding a single step
189
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
190
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
191
+
192
+ positions = create_mask_from_length(
193
+ lengths, max_length=x.shape[1]
194
+ ) * (torch.arange(x.shape[1]) + 1).unsqueeze(0).expand(x.shape[0], -1)
195
+ positions = positions.to(self.weights.device)
196
+ pos_emb = self.weights.index_select(0, positions.view(-1)).view(
197
+ bsz, seq_len, -1
198
+ ).detach()
199
+ return x + pos_emb
200
+
201
+ def max_positions(self):
202
+ """Maximum number of supported positions."""
203
+ return int(1e5) # an arbitrary large number
204
+
205
+
206
+ class RelPositionalEncoding(PositionalEncoding):
207
+ """Relative positional encoding module.
208
+ See : Appendix B in https://arxiv.org/abs/1901.02860
209
+ Args:
210
+ d_model (int): Embedding dimension.
211
+ dropout_rate (float): Dropout rate.
212
+ max_len (int): Maximum input length.
213
+ """
214
+ def __init__(self, d_model, dropout_rate, max_len=5000):
215
+ """Initialize class."""
216
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
217
+
218
+ def forward(self, x, lengths):
219
+ """Compute positional encoding.
220
+ Args:
221
+ x (torch.Tensor): Input tensor (batch, time, `*`).
222
+ Returns:
223
+ torch.Tensor: Encoded tensor (batch, time, `*`).
224
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
225
+ """
226
+ self.extend_pe(x)
227
+ x = x * self.xscale
228
+ pos_emb = self.pe[:, :x.size(1)]
229
+ return self.dropout(x) + self.dropout(pos_emb)
230
+
231
+
232
+ class MultiheadAttention(nn.Module):
233
+ def __init__(
234
+ self,
235
+ embed_dim,
236
+ num_heads,
237
+ kdim=None,
238
+ vdim=None,
239
+ dropout=0.,
240
+ bias=True,
241
+ add_bias_kv=False,
242
+ add_zero_attn=False,
243
+ self_attention=False,
244
+ encoder_decoder_attention=False
245
+ ):
246
+ super().__init__()
247
+ self.embed_dim = embed_dim
248
+ self.kdim = kdim if kdim is not None else embed_dim
249
+ self.vdim = vdim if vdim is not None else embed_dim
250
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
251
+
252
+ self.num_heads = num_heads
253
+ self.dropout = dropout
254
+ self.head_dim = embed_dim // num_heads
255
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
256
+ self.scaling = self.head_dim**-0.5
257
+
258
+ self.self_attention = self_attention
259
+ self.encoder_decoder_attention = encoder_decoder_attention
260
+
261
+ assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
262
+ 'value to be of the same size'
263
+
264
+ if self.qkv_same_dim:
265
+ self.in_proj_weight = Parameter(
266
+ torch.Tensor(3 * embed_dim, embed_dim)
267
+ )
268
+ else:
269
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
270
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
271
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
272
+
273
+ if bias:
274
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
275
+ else:
276
+ self.register_parameter('in_proj_bias', None)
277
+
278
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
279
+
280
+ if add_bias_kv:
281
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
282
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
283
+ else:
284
+ self.bias_k = self.bias_v = None
285
+
286
+ self.add_zero_attn = add_zero_attn
287
+
288
+ self.reset_parameters()
289
+
290
+ self.enable_torch_version = False
291
+ if hasattr(F, "multi_head_attention_forward"):
292
+ self.enable_torch_version = True
293
+ else:
294
+ self.enable_torch_version = False
295
+ self.last_attn_probs = None
296
+
297
+ def reset_parameters(self):
298
+ if self.qkv_same_dim:
299
+ nn.init.xavier_uniform_(self.in_proj_weight)
300
+ else:
301
+ nn.init.xavier_uniform_(self.k_proj_weight)
302
+ nn.init.xavier_uniform_(self.v_proj_weight)
303
+ nn.init.xavier_uniform_(self.q_proj_weight)
304
+
305
+ nn.init.xavier_uniform_(self.out_proj.weight)
306
+ if self.in_proj_bias is not None:
307
+ nn.init.constant_(self.in_proj_bias, 0.)
308
+ nn.init.constant_(self.out_proj.bias, 0.)
309
+ if self.bias_k is not None:
310
+ nn.init.xavier_normal_(self.bias_k)
311
+ if self.bias_v is not None:
312
+ nn.init.xavier_normal_(self.bias_v)
313
+
314
+ def forward(
315
+ self,
316
+ query,
317
+ key,
318
+ value,
319
+ key_padding_mask=None,
320
+ incremental_state=None,
321
+ need_weights=True,
322
+ static_kv=False,
323
+ attn_mask=None,
324
+ before_softmax=False,
325
+ need_head_weights=False,
326
+ enc_dec_attn_constraint_mask=None,
327
+ reset_attn_weight=None
328
+ ):
329
+ """Input shape: Time x Batch x Channel
330
+
331
+ Args:
332
+ key_padding_mask (ByteTensor, optional): mask to exclude
333
+ keys that are pads, of shape `(batch, src_len)`, where
334
+ padding elements are indicated by 1s.
335
+ need_weights (bool, optional): return the attention weights,
336
+ averaged over heads (default: False).
337
+ attn_mask (ByteTensor, optional): typically used to
338
+ implement causal attention, where the mask prevents the
339
+ attention from looking forward in time (default: None).
340
+ before_softmax (bool, optional): return the raw attention
341
+ weights and values before the attention softmax.
342
+ need_head_weights (bool, optional): return the attention
343
+ weights for each head. Implies *need_weights*. Default:
344
+ return the average attention weights over all heads.
345
+ """
346
+ if need_head_weights:
347
+ need_weights = True
348
+
349
+ tgt_len, bsz, embed_dim = query.size()
350
+ assert embed_dim == self.embed_dim
351
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
352
+
353
+ if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
354
+ if self.qkv_same_dim:
355
+ return F.multi_head_attention_forward(
356
+ query, key, value, self.embed_dim, self.num_heads,
357
+ self.in_proj_weight, self.in_proj_bias, self.bias_k,
358
+ self.bias_v, self.add_zero_attn, self.dropout,
359
+ self.out_proj.weight, self.out_proj.bias, self.training,
360
+ key_padding_mask, need_weights, attn_mask
361
+ )
362
+ else:
363
+ return F.multi_head_attention_forward(
364
+ query,
365
+ key,
366
+ value,
367
+ self.embed_dim,
368
+ self.num_heads,
369
+ torch.empty([0]),
370
+ self.in_proj_bias,
371
+ self.bias_k,
372
+ self.bias_v,
373
+ self.add_zero_attn,
374
+ self.dropout,
375
+ self.out_proj.weight,
376
+ self.out_proj.bias,
377
+ self.training,
378
+ key_padding_mask,
379
+ need_weights,
380
+ attn_mask,
381
+ use_separate_proj_weight=True,
382
+ q_proj_weight=self.q_proj_weight,
383
+ k_proj_weight=self.k_proj_weight,
384
+ v_proj_weight=self.v_proj_weight
385
+ )
386
+
387
+ if incremental_state is not None:
388
+ print('Not implemented error.')
389
+ exit()
390
+ else:
391
+ saved_state = None
392
+
393
+ if self.self_attention:
394
+ # self-attention
395
+ q, k, v = self.in_proj_qkv(query)
396
+ elif self.encoder_decoder_attention:
397
+ # encoder-decoder attention
398
+ q = self.in_proj_q(query)
399
+ if key is None:
400
+ assert value is None
401
+ k = v = None
402
+ else:
403
+ k = self.in_proj_k(key)
404
+ v = self.in_proj_v(key)
405
+
406
+ else:
407
+ q = self.in_proj_q(query)
408
+ k = self.in_proj_k(key)
409
+ v = self.in_proj_v(value)
410
+ q *= self.scaling
411
+
412
+ if self.bias_k is not None:
413
+ assert self.bias_v is not None
414
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
415
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
416
+ if attn_mask is not None:
417
+ attn_mask = torch.cat(
418
+ [attn_mask,
419
+ attn_mask.new_zeros(attn_mask.size(0), 1)],
420
+ dim=1
421
+ )
422
+ if key_padding_mask is not None:
423
+ key_padding_mask = torch.cat(
424
+ [
425
+ key_padding_mask,
426
+ key_padding_mask.new_zeros(
427
+ key_padding_mask.size(0), 1
428
+ )
429
+ ],
430
+ dim=1
431
+ )
432
+
433
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads,
434
+ self.head_dim).transpose(0, 1)
435
+ if k is not None:
436
+ k = k.contiguous().view(-1, bsz * self.num_heads,
437
+ self.head_dim).transpose(0, 1)
438
+ if v is not None:
439
+ v = v.contiguous().view(-1, bsz * self.num_heads,
440
+ self.head_dim).transpose(0, 1)
441
+
442
+ if saved_state is not None:
443
+ print('Not implemented error.')
444
+ exit()
445
+
446
+ src_len = k.size(1)
447
+
448
+ # This is part of a workaround to get around fork/join parallelism
449
+ # not supporting Optional types.
450
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size(
451
+ []
452
+ ):
453
+ key_padding_mask = None
454
+
455
+ if key_padding_mask is not None:
456
+ assert key_padding_mask.size(0) == bsz
457
+ assert key_padding_mask.size(1) == src_len
458
+
459
+ if self.add_zero_attn:
460
+ src_len += 1
461
+ k = torch.cat(
462
+ [k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1
463
+ )
464
+ v = torch.cat(
465
+ [v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1
466
+ )
467
+ if attn_mask is not None:
468
+ attn_mask = torch.cat(
469
+ [attn_mask,
470
+ attn_mask.new_zeros(attn_mask.size(0), 1)],
471
+ dim=1
472
+ )
473
+ if key_padding_mask is not None:
474
+ key_padding_mask = torch.cat(
475
+ [
476
+ key_padding_mask,
477
+ torch.zeros(key_padding_mask.size(0),
478
+ 1).type_as(key_padding_mask)
479
+ ],
480
+ dim=1
481
+ )
482
+
483
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
484
+ attn_weights = self.apply_sparse_mask(
485
+ attn_weights, tgt_len, src_len, bsz
486
+ )
487
+
488
+ assert list(attn_weights.size()) == [
489
+ bsz * self.num_heads, tgt_len, src_len
490
+ ]
491
+
492
+ if attn_mask is not None:
493
+ if len(attn_mask.shape) == 2:
494
+ attn_mask = attn_mask.unsqueeze(0)
495
+ elif len(attn_mask.shape) == 3:
496
+ attn_mask = attn_mask[:, None].repeat(
497
+ [1, self.num_heads, 1, 1]
498
+ ).reshape(bsz * self.num_heads, tgt_len, src_len)
499
+ attn_weights = attn_weights + attn_mask
500
+
501
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
502
+ attn_weights = attn_weights.view(
503
+ bsz, self.num_heads, tgt_len, src_len
504
+ )
505
+ attn_weights = attn_weights.masked_fill(
506
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
507
+ -1e9,
508
+ )
509
+ attn_weights = attn_weights.view(
510
+ bsz * self.num_heads, tgt_len, src_len
511
+ )
512
+
513
+ if key_padding_mask is not None:
514
+ # don't attend to padding symbols
515
+ attn_weights = attn_weights.view(
516
+ bsz, self.num_heads, tgt_len, src_len
517
+ )
518
+ attn_weights = attn_weights.masked_fill(
519
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
520
+ -1e9,
521
+ )
522
+ attn_weights = attn_weights.view(
523
+ bsz * self.num_heads, tgt_len, src_len
524
+ )
525
+
526
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
527
+
528
+ if before_softmax:
529
+ return attn_weights, v
530
+
531
+ attn_weights_float = softmax(attn_weights, dim=-1)
532
+ attn_weights = attn_weights_float.type_as(attn_weights)
533
+ attn_probs = F.dropout(
534
+ attn_weights_float.type_as(attn_weights),
535
+ p=self.dropout,
536
+ training=self.training
537
+ )
538
+
539
+ if reset_attn_weight is not None:
540
+ if reset_attn_weight:
541
+ self.last_attn_probs = attn_probs.detach()
542
+ else:
543
+ assert self.last_attn_probs is not None
544
+ attn_probs = self.last_attn_probs
545
+ attn = torch.bmm(attn_probs, v)
546
+ assert list(attn.size()) == [
547
+ bsz * self.num_heads, tgt_len, self.head_dim
548
+ ]
549
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
550
+ attn = self.out_proj(attn)
551
+
552
+ if need_weights:
553
+ attn_weights = attn_weights_float.view(
554
+ bsz, self.num_heads, tgt_len, src_len
555
+ ).transpose(1, 0)
556
+ if not need_head_weights:
557
+ # average attention weights over heads
558
+ attn_weights = attn_weights.mean(dim=0)
559
+ else:
560
+ attn_weights = None
561
+
562
+ return attn, (attn_weights, attn_logits)
563
+
564
+ def in_proj_qkv(self, query):
565
+ return self._in_proj(query).chunk(3, dim=-1)
566
+
567
+ def in_proj_q(self, query):
568
+ if self.qkv_same_dim:
569
+ return self._in_proj(query, end=self.embed_dim)
570
+ else:
571
+ bias = self.in_proj_bias
572
+ if bias is not None:
573
+ bias = bias[:self.embed_dim]
574
+ return F.linear(query, self.q_proj_weight, bias)
575
+
576
+ def in_proj_k(self, key):
577
+ if self.qkv_same_dim:
578
+ return self._in_proj(
579
+ key, start=self.embed_dim, end=2 * self.embed_dim
580
+ )
581
+ else:
582
+ weight = self.k_proj_weight
583
+ bias = self.in_proj_bias
584
+ if bias is not None:
585
+ bias = bias[self.embed_dim:2 * self.embed_dim]
586
+ return F.linear(key, weight, bias)
587
+
588
+ def in_proj_v(self, value):
589
+ if self.qkv_same_dim:
590
+ return self._in_proj(value, start=2 * self.embed_dim)
591
+ else:
592
+ weight = self.v_proj_weight
593
+ bias = self.in_proj_bias
594
+ if bias is not None:
595
+ bias = bias[2 * self.embed_dim:]
596
+ return F.linear(value, weight, bias)
597
+
598
+ def _in_proj(self, input, start=0, end=None):
599
+ weight = self.in_proj_weight
600
+ bias = self.in_proj_bias
601
+ weight = weight[start:end, :]
602
+ if bias is not None:
603
+ bias = bias[start:end]
604
+ return F.linear(input, weight, bias)
605
+
606
+ def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
607
+ return attn_weights
608
+
609
+
610
+ class TransformerFFNLayer(nn.Module):
611
+ def __init__(
612
+ self,
613
+ hidden_size,
614
+ filter_size,
615
+ padding="SAME",
616
+ kernel_size=1,
617
+ dropout=0.,
618
+ act='gelu'
619
+ ):
620
+ super().__init__()
621
+ self.kernel_size = kernel_size
622
+ self.dropout = dropout
623
+ self.act = act
624
+ if padding == 'SAME':
625
+ self.ffn_1 = nn.Conv1d(
626
+ hidden_size,
627
+ filter_size,
628
+ kernel_size,
629
+ padding=kernel_size // 2
630
+ )
631
+ elif padding == 'LEFT':
632
+ self.ffn_1 = nn.Sequential(
633
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
634
+ nn.Conv1d(hidden_size, filter_size, kernel_size)
635
+ )
636
+ self.ffn_2 = nn.Linear(filter_size, hidden_size)
637
+
638
+ def forward(
639
+ self,
640
+ x,
641
+ ):
642
+ # x: T x B x C
643
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
644
+ x = x * self.kernel_size**-0.5
645
+
646
+ if self.act == 'gelu':
647
+ x = F.gelu(x)
648
+ if self.act == 'relu':
649
+ x = F.relu(x)
650
+ if self.act == 'swish':
651
+ x = F.silu(x)
652
+ x = F.dropout(x, self.dropout, training=self.training)
653
+ x = self.ffn_2(x)
654
+ return x
655
+
656
+
657
+ class EncoderSelfAttentionLayer(nn.Module):
658
+ def __init__(
659
+ self,
660
+ c,
661
+ num_heads,
662
+ dropout,
663
+ attention_dropout=0.1,
664
+ relu_dropout=0.1,
665
+ kernel_size=9,
666
+ padding='SAME',
667
+ norm='ln',
668
+ act='gelu',
669
+ padding_set_zero=True
670
+ ):
671
+ super().__init__()
672
+ self.c = c
673
+ self.dropout = dropout
674
+ self.num_heads = num_heads
675
+ self.padding_set_zero = padding_set_zero
676
+ if num_heads > 0:
677
+ if norm == 'ln':
678
+ self.layer_norm1 = LayerNorm(c)
679
+ elif norm == 'bn':
680
+ self.layer_norm1 = BatchNorm1dTBC(c)
681
+ self.self_attn = MultiheadAttention(
682
+ self.c,
683
+ num_heads=num_heads,
684
+ self_attention=True,
685
+ dropout=attention_dropout,
686
+ bias=False,
687
+ )
688
+ if norm == 'ln':
689
+ self.layer_norm2 = LayerNorm(c)
690
+ elif norm == 'bn':
691
+ self.layer_norm2 = BatchNorm1dTBC(c)
692
+ self.ffn = TransformerFFNLayer(
693
+ c,
694
+ 4 * c,
695
+ kernel_size=kernel_size,
696
+ dropout=relu_dropout,
697
+ padding=padding,
698
+ act=act
699
+ )
700
+
701
+ def forward(self, x, encoder_padding_mask=None, **kwargs):
702
+ layer_norm_training = kwargs.get('layer_norm_training', None)
703
+ if layer_norm_training is not None:
704
+ self.layer_norm1.training = layer_norm_training
705
+ self.layer_norm2.training = layer_norm_training
706
+ if self.num_heads > 0:
707
+ residual = x
708
+ x = self.layer_norm1(x)
709
+ x, _ = self.self_attn(
710
+ query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
711
+ )
712
+ x = F.dropout(x, self.dropout, training=self.training)
713
+ x = residual + x
714
+ if self.padding_set_zero:
715
+ x = x * (1 - encoder_padding_mask.float()).transpose(0,
716
+ 1)[...,
717
+ None]
718
+
719
+ residual = x
720
+ x = self.layer_norm2(x)
721
+ x = self.ffn(x)
722
+ x = F.dropout(x, self.dropout, training=self.training)
723
+ x = residual + x
724
+ if self.padding_set_zero:
725
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[...,
726
+ None]
727
+ return x
728
+
729
+
730
+ class TransformerEncoderLayer(nn.Module):
731
+ def __init__(
732
+ self,
733
+ hidden_size,
734
+ dropout,
735
+ kernel_size,
736
+ num_heads=2,
737
+ norm='ln',
738
+ padding_set_zero=True,
739
+ ):
740
+ super().__init__()
741
+ self.hidden_size = hidden_size
742
+ self.dropout = dropout
743
+ self.num_heads = num_heads
744
+ self.op = EncoderSelfAttentionLayer(
745
+ hidden_size,
746
+ num_heads,
747
+ dropout=dropout,
748
+ attention_dropout=0.0,
749
+ relu_dropout=dropout,
750
+ kernel_size=kernel_size,
751
+ padding="SAME",
752
+ norm=norm,
753
+ act="gelu",
754
+ padding_set_zero=padding_set_zero
755
+ )
756
+
757
+ def forward(self, x, **kwargs):
758
+ return self.op(x, **kwargs)
759
+
760
+
761
+ class FFTBlocks(nn.Module):
762
+ def __init__(
763
+ self,
764
+ hidden_size,
765
+ num_layers,
766
+ ffn_kernel_size=9,
767
+ dropout=0.1,
768
+ num_heads=2,
769
+ use_last_norm=True,
770
+ padding_set_zero=True,
771
+ ):
772
+ super().__init__()
773
+ self.num_layers = num_layers
774
+ embed_dim = self.hidden_size = hidden_size
775
+ self.dropout = dropout
776
+ self.use_last_norm = use_last_norm
777
+ self.padding_set_zero = padding_set_zero
778
+
779
+ self.layers = nn.ModuleList([])
780
+ self.layers.extend(
781
+ [
782
+ TransformerEncoderLayer(
783
+ self.hidden_size,
784
+ self.dropout,
785
+ kernel_size=ffn_kernel_size,
786
+ num_heads=num_heads,
787
+ padding_set_zero=padding_set_zero,
788
+ ) for _ in range(self.num_layers)
789
+ ]
790
+ )
791
+ if self.use_last_norm:
792
+ self.layer_norm = nn.LayerNorm(embed_dim)
793
+ else:
794
+ self.layer_norm = None
795
+
796
+ def forward(self, x, padding_mask=None, attn_mask=None):
797
+ """
798
+ :param x: [B, T, C]
799
+ :param padding_mask: [B, T]
800
+ :return: [B, T, C] or [L, B, T, C]
801
+ """
802
+ if padding_mask is None:
803
+ padding_mask = torch.zeros(x.size(0), x.size(1)).to(x.device)
804
+ nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float(
805
+ )[:, :, None] # [T, B, 1]
806
+ # B x T x C -> T x B x C
807
+ x = x.transpose(0, 1)
808
+ if self.padding_set_zero:
809
+ x = x * nonpadding_mask_TB
810
+ for layer in self.layers:
811
+ x = layer(
812
+ x, encoder_padding_mask=padding_mask, attn_mask=attn_mask
813
+ )
814
+ if self.padding_set_zero:
815
+ x = x * nonpadding_mask_TB
816
+ if self.use_last_norm:
817
+ x = self.layer_norm(x)
818
+ if self.padding_set_zero:
819
+ x = x * nonpadding_mask_TB
820
+
821
+ x = x.transpose(0, 1) # [B, T, C]
822
+ return x
823
+
824
+
825
+ class FastSpeech2EncoderBase(nn.Module):
826
+ def __init__(
827
+ self,
828
+ d_model: int,
829
+ num_layers: int,
830
+ num_heads: int,
831
+ ffn_kernel_size: int,
832
+ d_out: int,
833
+ dropout: float = 0.1,
834
+ rel_pos: bool = True,
835
+ padding_set_zero: bool = True
836
+ ):
837
+ super().__init__()
838
+ self.rel_pos = rel_pos
839
+
840
+ if self.rel_pos:
841
+ self.pos_encoding = RelPositionalEncoding(
842
+ d_model, dropout_rate=0.0
843
+ )
844
+ else:
845
+ self.pos_encoding = SinusoidalPositionalEmbedding(
846
+ d_model, padding_idx=0
847
+ )
848
+ self.dropout = dropout
849
+ self.embed_scale = math.sqrt(d_model)
850
+
851
+ self.layers = FFTBlocks(
852
+ hidden_size=d_model,
853
+ num_layers=num_layers,
854
+ ffn_kernel_size=ffn_kernel_size,
855
+ dropout=dropout,
856
+ num_heads=num_heads,
857
+ use_last_norm=True,
858
+ padding_set_zero=padding_set_zero
859
+ )
860
+
861
+ self.out_proj = nn.Linear(d_model, d_out)
862
+ self.apply(self.init_weights)
863
+
864
+ def init_weights(self, m):
865
+ if isinstance(m, nn.Linear):
866
+ nn.init.xavier_uniform_(m.weight)
867
+ if m.bias is not None:
868
+ nn.init.constant_(m.bias, 0.)
869
+ elif isinstance(m, nn.Embedding):
870
+ nn.init.normal_(m.weight, mean=0, std=m.embedding_dim**-0.5)
871
+
872
+
873
+ @dataclass
874
+ class SpkConfig:
875
+ encoding_format: str
876
+ num_spk: int | None = None
877
+ spk_embed_dim: int | None = None
878
+
879
+ def __post_init__(self):
880
+ allowed_formats = {"id", "embedding"}
881
+ assert self.encoding_format in allowed_formats, f"encoding_format must be one of {allowed_formats}, got '{self.encoding_format}'"
882
+ if self.encoding_format == "id":
883
+ assert self.num_spk is not None
884
+ if self.encoding_format == "embedding":
885
+ assert self.spk_embed_dim is not None
886
+
887
+
888
+ class FastSpeech2PhonemeEncoder(FastSpeech2EncoderBase):
889
+ def __init__(
890
+ self,
891
+ phone_vocab_size,
892
+ d_model,
893
+ num_layers,
894
+ num_heads,
895
+ ffn_kernel_size,
896
+ d_out,
897
+ dropout=0.1,
898
+ rel_pos=False,
899
+ spk_config: SpkConfig | None = None,
900
+ padding_set_zero: bool = True
901
+ ):
902
+ super().__init__(
903
+ d_model=d_model,
904
+ num_layers=num_layers,
905
+ num_heads=num_heads,
906
+ ffn_kernel_size=ffn_kernel_size,
907
+ d_out=d_out,
908
+ dropout=dropout,
909
+ rel_pos=rel_pos,
910
+ padding_set_zero=padding_set_zero
911
+ )
912
+ self.phone_embed = Embedding(phone_vocab_size, d_model)
913
+ self.spk_config = spk_config
914
+ if spk_config is not None:
915
+ if spk_config.encoding_format == "id":
916
+ self.spk_embed_proj = Embedding(
917
+ spk_config.num_spk + 1, d_model
918
+ )
919
+ elif spk_config.encoding_format == "embedding":
920
+ self.spk_embed_proj = Linear(spk_config.spk_embed_dim, d_model)
921
+
922
+ def forward(
923
+ self, phoneme: torch.Tensor, lengths: Sequence[int], spk: torch.Tensor
924
+ ):
925
+ x = self.embed_scale * self.phone_embed(phoneme)
926
+ x = self.pos_encoding(x, lengths)
927
+ x = F.dropout(x, p=self.dropout, training=self.training)
928
+
929
+ padding_mask = ~create_mask_from_length(lengths).to(phoneme.device)
930
+ x = self.layers(x, padding_mask=padding_mask)
931
+
932
+ if self.spk_config is not None:
933
+ spk_embed = self.spk_embed_proj(spk).unsqueeze(1)
934
+ x = x + spk_embed
935
+
936
+ x = self.out_proj(x)
937
+
938
+ return {"output": x, "mask": ~padding_mask}
939
+
940
+
941
+ class FastSpeech2MIDIEncoder(FastSpeech2PhonemeEncoder):
942
+ def __init__(
943
+ self,
944
+ phone_vocab_size: int,
945
+ midi_vocab_size: int,
946
+ slur_vocab_size: int,
947
+ spk_config: SpkConfig | None,
948
+ d_model: int,
949
+ num_layers: int,
950
+ num_heads: int,
951
+ ffn_kernel_size: int,
952
+ d_out: int,
953
+ dropout: float = 0.1,
954
+ rel_pos: bool = True,
955
+ padding_set_zero: bool = True
956
+ ):
957
+ super().__init__(
958
+ phone_vocab_size=phone_vocab_size,
959
+ d_model=d_model,
960
+ num_layers=num_layers,
961
+ num_heads=num_heads,
962
+ ffn_kernel_size=ffn_kernel_size,
963
+ d_out=d_out,
964
+ dropout=dropout,
965
+ rel_pos=rel_pos,
966
+ spk_config=spk_config,
967
+ padding_set_zero=padding_set_zero
968
+ )
969
+ self.midi_embed = Embedding(midi_vocab_size, d_model, padding_idx=0)
970
+ self.midi_dur_embed = Linear(1, d_model)
971
+ self.is_slur_embed = Embedding(slur_vocab_size, d_model)
972
+
973
+ def forward(
974
+ self,
975
+ phoneme: torch.Tensor,
976
+ midi: torch.Tensor,
977
+ midi_duration: torch.Tensor,
978
+ is_slur: torch.Tensor,
979
+ lengths: Sequence[int],
980
+ spk: torch.Tensor | None = None,
981
+ ):
982
+ x = self.embed_scale * self.phone_embed(phoneme)
983
+ midi_embedding = self.midi_embed(midi)
984
+ midi_dur_embedding = self.midi_dur_embed(midi_duration[:, :, None])
985
+ slur_embedding = self.is_slur_embed(is_slur)
986
+
987
+ x = x + midi_embedding + midi_dur_embedding + slur_embedding
988
+ x = self.pos_encoding(x, lengths)
989
+ x = F.dropout(x, p=self.dropout, training=self.training)
990
+
991
+ padding_mask = ~create_mask_from_length(lengths).to(phoneme.device)
992
+ x = self.layers(x, padding_mask=padding_mask)
993
+
994
+ if self.spk_config is not None:
995
+ spk_embed = self.spk_embed_proj(spk).unsqueeze(1)
996
+ x = x + spk_embed
997
+
998
+ x = self.out_proj(x)
999
+
1000
+ return {"output": x, "mask": ~padding_mask}
1001
+
1002
+
1003
+ class FastSpeech2PitchEncoder(FastSpeech2EncoderBase):
1004
+ def __init__(
1005
+ self,
1006
+ phone_vocab_size,
1007
+ d_model,
1008
+ num_layers,
1009
+ num_heads,
1010
+ ffn_kernel_size,
1011
+ d_out,
1012
+ dropout=0.1,
1013
+ rel_pos=False,
1014
+ padding_set_zero=True
1015
+ ):
1016
+ super().__init__(
1017
+ d_model=d_model,
1018
+ num_layers=num_layers,
1019
+ num_heads=num_heads,
1020
+ ffn_kernel_size=ffn_kernel_size,
1021
+ d_out=d_out,
1022
+ dropout=dropout,
1023
+ rel_pos=rel_pos,
1024
+ padding_set_zero=padding_set_zero
1025
+ )
1026
+ self.phone_embed = Embedding(phone_vocab_size, d_model)
1027
+ self.pitch_embed = Embedding(300, d_model)
1028
+
1029
+ def forward(self, phoneme: torch.Tensor, lengths: Sequence[int]):
1030
+ x = self.embed_scale * self.phone_embed(phoneme)
1031
+ x = self.pos_encoding(x, lengths)
1032
+ x = F.dropout(x, p=self.dropout, training=self.training)
1033
+
1034
+ padding_mask = ~create_mask_from_length(lengths).to(phoneme.device)
1035
+ x = self.layers(x, padding_mask=padding_mask)
1036
+
1037
+ x = self.out_proj(x)
1038
+
1039
+ return {"output": x, "mask": ~padding_mask}
1040
+
1041
+ def encode_pitch(self, f0, uv):
1042
+
1043
+ f0_denorm = denorm_f0(f0, uv)
1044
+ pitch = f0_to_coarse(f0_denorm)
1045
+ pitch_embed = self.pitch_embed(pitch)
1046
+ return {"output": pitch_embed}
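A minimal usage sketch for the singing-voice encoder above. All vocabulary sizes, dimensions, and tensor shapes below are illustrative placeholders rather than values from the repository configs, and it assumes the `Embedding`, `Linear`, positional-encoding, and `create_mask_from_length` helpers imported earlier in this file behave as expected.

import torch

spk_config = SpkConfig(encoding_format="id", num_spk=8)
encoder = FastSpeech2MIDIEncoder(
    phone_vocab_size=64, midi_vocab_size=130, slur_vocab_size=2,
    spk_config=spk_config, d_model=256, num_layers=4, num_heads=2,
    ffn_kernel_size=9, d_out=256,
)
phoneme = torch.randint(0, 64, (2, 20))    # (batch, tokens)
midi = torch.randint(0, 130, (2, 20))
midi_duration = torch.rand(2, 20)          # per-token note durations
is_slur = torch.randint(0, 2, (2, 20))
spk = torch.randint(0, 8, (2,))
out = encoder(phoneme, midi, midi_duration, is_slur, lengths=[20, 15], spk=spk)
# out["output"]: (2, 20, 256); out["mask"]: (2, 20) bool, True at valid positions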
models/content_encoder/text_encoder.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel
4
+ from transformers.modeling_outputs import BaseModelOutput
5
+
6
+ try:
7
+ import torch_npu
8
+ from torch_npu.contrib import transfer_to_npu
9
+ DEVICE_TYPE = "npu"
10
+ except ModuleNotFoundError:
11
+ DEVICE_TYPE = "cuda"
12
+
13
+
14
+ class TransformersTextEncoderBase(nn.Module):
15
+ def __init__(self, model_name: str, embed_dim: int):
16
+ super().__init__()
17
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ self.model = AutoModel.from_pretrained(model_name)
19
+ self.proj = nn.Linear(self.model.config.hidden_size, embed_dim)
20
+
21
+ def forward(
22
+ self,
23
+ text: list[str],
24
+ ):
25
+ output, mask = self.encode(text)
26
+ output = self.projection(output)
27
+ return {"output": output, "mask": mask}
28
+
29
+ def encode(self, text: list[str]):
30
+ device = self.model.device
31
+ batch = self.tokenizer(
32
+ text,
33
+ max_length=self.tokenizer.model_max_length,
34
+ padding=True,
35
+ truncation=True,
36
+ return_tensors="pt",
37
+ )
38
+ input_ids = batch.input_ids.to(device)
39
+ attention_mask = batch.attention_mask.to(device)
40
+ output: BaseModelOutput = self.model(
41
+ input_ids=input_ids, attention_mask=attention_mask
42
+ )
43
+ output = output.last_hidden_state
44
+ mask = (attention_mask == 1).to(device)
45
+ return output, mask
46
+
47
+ def projection(self, x):
48
+ return self.proj(x)
49
+
50
+
51
+ class T5TextEncoder(TransformersTextEncoderBase):
52
+ def __init__(
53
+ self, embed_dim: int, model_name: str = "google/flan-t5-large"
54
+ ):
55
+ nn.Module.__init__(self)
56
+ self.tokenizer = T5Tokenizer.from_pretrained(model_name)
57
+ self.model = T5EncoderModel.from_pretrained(model_name)
58
+ for param in self.model.parameters():
59
+ param.requires_grad = False
60
+ self.model.eval()
61
+ self.proj = nn.Linear(self.model.config.hidden_size, embed_dim)
62
+
63
+ def encode(
64
+ self,
65
+ text: list[str],
66
+ ):
67
+ with torch.no_grad(), torch.amp.autocast(
68
+ device_type=DEVICE_TYPE, enabled=False
69
+ ):
70
+ return super().encode(text)
71
+
72
+
73
+ if __name__ == "__main__":
74
+ text_encoder = T5TextEncoder(embed_dim=512)
75
+ text = ["a man is speaking", "a woman is singing while a dog is barking"]
76
+
77
+ output = text_encoder(text)
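For reference, a rough sketch of what the `__main__` example above is expected to produce; the token count depends on the T5 tokenizer and is only indicative, while 512 is the `embed_dim` passed to the projection over the frozen flan-t5-large hidden states.

print(output["output"].shape)  # e.g. torch.Size([2, num_tokens, 512])
print(output["mask"].shape)    # torch.Size([2, num_tokens]); True at real tokens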
models/content_encoder/vision_encoder.py ADDED
@@ -0,0 +1,34 @@
1
+ from typing import Sequence
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from utils.torch_utilities import create_mask_from_length
8
+
9
+
10
+ class MlpVideoEncoder(nn.Module):
11
+ def __init__(
12
+ self,
13
+ video_feat_dim: int,
14
+ embed_dim: int,
15
+ ):
16
+ super().__init__()
17
+ self.mlp = nn.Linear(video_feat_dim, embed_dim)
18
+ self.init_weights()
19
+
20
+ def init_weights(self):
21
+ def _init_weights(module):
22
+ if isinstance(module, nn.Linear):
23
+ nn.init.xavier_uniform_(module.weight)
24
+ if module.bias is not None:
25
+ nn.init.constant_(module.bias, 0.)
26
+
27
+ self.apply(_init_weights)
28
+
29
+ def forward(self, frames: torch.Tensor, frame_nums: Sequence[int]):
30
+ device = frames.device
31
+ x = F.normalize(frames, p=2, dim=-1)
32
+ x = self.mlp(x)
33
+ mask = create_mask_from_length(frame_nums).to(device)
34
+ return {"output": x, "mask": mask}
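A minimal sketch of how `MlpVideoEncoder` could be exercised; the 512-dimensional frame features are a placeholder for whatever upstream visual feature extractor is assumed, and `create_mask_from_length` is taken to return a boolean (batch, max_frames) mask as used above.

import torch

encoder = MlpVideoEncoder(video_feat_dim=512, embed_dim=256)
frames = torch.randn(2, 16, 512)           # (batch, frames, feat_dim)
out = encoder(frames, frame_nums=[16, 10])
# out["output"]: (2, 16, 256) L2-normalized then linearly projected features
# out["mask"]:   (2, 16) with the last 6 frames of the second clip masked out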
models/diffusion.py ADDED
@@ -0,0 +1,943 @@
1
+ from typing import Sequence
2
+ import random
3
+ from typing import Any
4
+
5
+ from tqdm import tqdm
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import diffusers.schedulers as noise_schedulers
10
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
11
+ from diffusers.utils.torch_utils import randn_tensor
12
+
13
+ from models.autoencoder.autoencoder_base import AutoEncoderBase
14
+ from models.content_encoder.content_encoder import ContentEncoder
15
+ from models.content_adapter import ContentAdapterBase, ContentEncoderAdapterMixin
16
+ from models.common import (
17
+ LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
18
+ DurationAdapterMixin
19
+ )
20
+ from utils.torch_utilities import (
21
+ create_alignment_path, create_mask_from_length, loss_with_mask,
22
+ trim_or_pad_length
23
+ )
24
+
25
+
26
+ class DiffusionMixin:
27
+ def __init__(
28
+ self,
29
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
30
+ snr_gamma: float = None,
31
+ cfg_drop_ratio: float = 0.2
32
+ ) -> None:
33
+ self.noise_scheduler_name = noise_scheduler_name
34
+ self.snr_gamma = snr_gamma
35
+ self.classifier_free_guidance = cfg_drop_ratio > 0.0
36
+ self.cfg_drop_ratio = cfg_drop_ratio
37
+ self.noise_scheduler = noise_schedulers.DDPMScheduler.from_pretrained(
38
+ self.noise_scheduler_name, subfolder="scheduler"
39
+ )
40
+
41
+ def compute_snr(self, timesteps) -> torch.Tensor:
42
+ """
43
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
44
+ """
45
+ alphas_cumprod = self.noise_scheduler.alphas_cumprod
46
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
47
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod)**0.5
48
+
49
+ # Expand the tensors.
50
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
51
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device
52
+ )[timesteps].float()
53
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
54
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
55
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
56
+
57
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
58
+ device=timesteps.device
59
+ )[timesteps].float()
60
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
61
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[...,
62
+ None]
63
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
64
+
65
+ # Compute SNR.
66
+ snr = (alpha / sigma)**2
67
+ return snr
68
+
69
+ def get_timesteps(
70
+ self,
71
+ batch_size: int,
72
+ device: torch.device,
73
+ training: bool = True
74
+ ) -> torch.Tensor:
75
+ if training:
76
+ timesteps = torch.randint(
77
+ 0,
78
+ self.noise_scheduler.config.num_train_timesteps,
79
+ (batch_size, ),
80
+ device=device
81
+ )
82
+ else:
83
+ # validation on half of the total timesteps
84
+ timesteps = (self.noise_scheduler.config.num_train_timesteps //
85
+ 2) * torch.ones((batch_size, ),
86
+ dtype=torch.int64,
87
+ device=device)
88
+
89
+ timesteps = timesteps.long()
90
+ return timesteps
91
+
92
+ def get_input_target_and_timesteps(
93
+ self,
94
+ latent: torch.Tensor,
95
+ training: bool,
96
+ ):
97
+ batch_size = latent.shape[0]
98
+ device = latent.device
99
+ num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
100
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
101
+ timesteps = self.get_timesteps(batch_size, device, training=training)
102
+ noise = torch.randn_like(latent)
103
+ noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
104
+ target = self.get_target(latent, noise, timesteps)
105
+ return noisy_latent, target, timesteps
106
+
107
+ def get_target(
108
+ self, latent: torch.Tensor, noise: torch.Tensor,
109
+ timesteps: torch.Tensor
110
+ ) -> torch.Tensor:
111
+ """
112
+ Get the target for loss depending on the prediction type
113
+ """
114
+ if self.noise_scheduler.config.prediction_type == "epsilon":
115
+ target = noise
116
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
117
+ target = self.noise_scheduler.get_velocity(
118
+ latent, noise, timesteps
119
+ )
120
+ else:
121
+ raise ValueError(
122
+ f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
123
+ )
124
+ return target
125
+
126
+ def loss_with_snr(
127
+ self,
128
+ pred: torch.Tensor,
129
+ target: torch.Tensor,
130
+ timesteps: torch.Tensor,
131
+ mask: torch.Tensor,
132
+ reduce: bool = True
133
+ ) -> torch.Tensor:
134
+ if self.snr_gamma is None:
135
+ loss = F.mse_loss(pred.float(), target.float(), reduction="none")
136
+ loss = loss_with_mask(loss, mask, reduce=reduce)
137
+ else:
138
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
139
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L1006
140
+ snr = self.compute_snr(timesteps)
141
+ mse_loss_weights = torch.stack(
142
+ [
143
+ snr,
144
+ self.snr_gamma * torch.ones_like(timesteps),
145
+ ],
146
+ dim=1,
147
+ ).min(dim=1)[0]
148
+ # division by (snr + 1) does not work well; the reason is unclear
149
+ mse_loss_weights = mse_loss_weights / snr
150
+ loss = F.mse_loss(pred.float(), target.float(), reduction="none")
151
+ loss = loss_with_mask(loss, mask, reduce=False) * mse_loss_weights
152
+ if reduce:
153
+ loss = loss.mean()
154
+ return loss
155
+
156
+ def rescale_cfg(
157
+ self, pred_cond: torch.Tensor, pred_cfg: torch.Tensor,
158
+ guidance_rescale: float
159
+ ):
160
+ """
161
+ Rescale `pred_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
162
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
163
+ """
164
+ std_cond = pred_cond.std(
165
+ dim=list(range(1, pred_cond.ndim)), keepdim=True
166
+ )
167
+ std_cfg = pred_cfg.std(dim=list(range(1, pred_cfg.ndim)), keepdim=True)
168
+
169
+ pred_rescaled = pred_cfg * (std_cond / std_cfg)
170
+ pred_cfg = guidance_rescale * pred_rescaled + (
171
+ 1 - guidance_rescale
172
+ ) * pred_cfg
173
+ return pred_cfg
174
+
175
+
176
+ class SingleTaskCrossAttentionAudioDiffusion(
177
+ LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
178
+ DiffusionMixin, ContentEncoderAdapterMixin
179
+ ):
180
+ def __init__(
181
+ self,
182
+ autoencoder: AutoEncoderBase,
183
+ content_encoder: ContentEncoder,
184
+ backbone: nn.Module,
185
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
186
+ snr_gamma: float = None,
187
+ cfg_drop_ratio: float = 0.2,
188
+ ):
189
+ nn.Module.__init__(self)
190
+ DiffusionMixin.__init__(
191
+ self, noise_scheduler_name, snr_gamma, cfg_drop_ratio
192
+ )
193
+ ContentEncoderAdapterMixin.__init__(
194
+ self, content_encoder=content_encoder
195
+ )
196
+
197
+ self.autoencoder = autoencoder
198
+ for param in self.autoencoder.parameters():
199
+ param.requires_grad = False
200
+
201
+ if hasattr(self.content_encoder, "audio_encoder"):
202
+ self.content_encoder.audio_encoder.model = self.autoencoder
203
+
204
+ self.backbone = backbone
205
+ self.dummy_param = nn.Parameter(torch.empty(0))
206
+
207
+ def forward(
208
+ self, content: list[Any], condition: list[Any], task: list[str],
209
+ waveform: torch.Tensor, waveform_lengths: torch.Tensor, **kwargs
210
+ ):
211
+ device = self.dummy_param.device
212
+
213
+ self.autoencoder.eval()
214
+ with torch.no_grad():
215
+ latent, latent_mask = self.autoencoder.encode(
216
+ waveform.unsqueeze(1), waveform_lengths
217
+ )
218
+
219
+ content_dict = self.encode_content(content, task, device)
220
+ content, content_mask = content_dict["content"], content_dict[
221
+ "content_mask"]
222
+
223
+ if self.training and self.classifier_free_guidance:
224
+ mask_indices = [
225
+ k for k in range(len(waveform))
226
+ if random.random() < self.cfg_drop_ratio
227
+ ]
228
+ if len(mask_indices) > 0:
229
+ content[mask_indices] = 0
230
+
231
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
232
+ latent, self.training
233
+ )
234
+
235
+ pred: torch.Tensor = self.backbone(
236
+ x=noisy_latent,
237
+ timesteps=timesteps,
238
+ context=content,
239
+ x_mask=latent_mask,
240
+ context_mask=content_mask
241
+ )
242
+
243
+ pred = pred.transpose(1, self.autoencoder.time_dim)
244
+ target = target.transpose(1, self.autoencoder.time_dim)
245
+ loss = self.loss_with_snr(pred, target, timesteps, latent_mask)
246
+
247
+ return loss
248
+
249
+ def prepare_latent(
250
+ self, batch_size: int, scheduler: SchedulerMixin,
251
+ latent_shape: Sequence[int], dtype: torch.dtype, device: str
252
+ ):
253
+ shape = (batch_size, *latent_shape)
254
+ latent = randn_tensor(
255
+ shape, generator=None, device=device, dtype=dtype
256
+ )
257
+ # scale the initial noise by the standard deviation required by the scheduler
258
+ latent = latent * scheduler.init_noise_sigma
259
+ return latent
260
+
261
+ def iterative_denoise(
262
+ self,
263
+ latent: torch.Tensor,
264
+ scheduler: SchedulerMixin,
265
+ verbose: bool,
266
+ cfg: bool,
267
+ cfg_scale: float,
268
+ cfg_rescale: float,
269
+ backbone_input: dict,
270
+ ):
271
+ timesteps = scheduler.timesteps
272
+ num_steps = len(timesteps)
273
+ num_warmup_steps = len(timesteps) - num_steps * scheduler.order
274
+ progress_bar = tqdm(range(num_steps), disable=not verbose)
275
+
276
+ for i, timestep in enumerate(timesteps):
277
+ # expand the latent if we are doing classifier free guidance
278
+ if cfg:
279
+ latent_input = torch.cat([latent, latent])
280
+ else:
281
+ latent_input = latent
282
+ latent_input = scheduler.scale_model_input(latent_input, timestep)
283
+
284
+ noise_pred = self.backbone(
285
+ x=latent_input, timesteps=timestep, **backbone_input
286
+ )
287
+
288
+ # perform guidance
289
+ if cfg:
290
+ noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
291
+ noise_pred = noise_pred_uncond + cfg_scale * (
292
+ noise_pred_content - noise_pred_uncond
293
+ )
294
+ if cfg_rescale != 0.0:
295
+ noise_pred = self.rescale_cfg(
296
+ noise_pred_content, noise_pred, cfg_rescale
297
+ )
298
+
299
+ # compute the previous noisy sample x_t -> x_t-1
300
+ latent = scheduler.step(noise_pred, timestep, latent).prev_sample
301
+
302
+ # call the callback, if provided
303
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
304
+ (i + 1) % scheduler.order == 0):
305
+ progress_bar.update(1)
306
+
307
+ progress_bar.close()
308
+
309
+ return latent
310
+
311
+ @torch.no_grad()
312
+ def inference(
313
+ self,
314
+ content: list[Any],
315
+ condition: list[Any],
316
+ task: list[str],
317
+ latent_shape: Sequence[int],
318
+ scheduler: SchedulerMixin,
319
+ num_steps: int = 50,
320
+ guidance_scale: float = 3.0,
321
+ guidance_rescale: float = 0.0,
322
+ disable_progress: bool = True,
323
+ **kwargs
324
+ ):
325
+ device = self.dummy_param.device
326
+ classifier_free_guidance = guidance_scale > 1.0
327
+ batch_size = len(content)
328
+
329
+ content_output: dict[str, torch.Tensor] = self.encode_content(
330
+ content, task, device
331
+ )
332
+ content, content_mask = content_output["content"], content_output[
333
+ "content_mask"]
334
+
335
+ if classifier_free_guidance:
336
+ uncond_content = torch.zeros_like(content)
337
+ uncond_content_mask = content_mask.detach().clone()
338
+ content = torch.cat([uncond_content, content])
339
+ content_mask = torch.cat([uncond_content_mask, content_mask])
340
+
341
+ scheduler.set_timesteps(num_steps, device=device)
342
+
343
+ latent = self.prepare_latent(
344
+ batch_size, scheduler, latent_shape, content.dtype, device
345
+ )
346
+ latent = self.iterative_denoise(
347
+ latent=latent,
348
+ scheduler=scheduler,
349
+ verbose=not disable_progress,
350
+ cfg=classifier_free_guidance,
351
+ cfg_scale=guidance_scale,
352
+ cfg_rescale=guidance_rescale,
353
+ backbone_input={
354
+ "context": content,
355
+ "context_mask": content_mask
356
+ },
357
+ )
358
+
359
+ waveform = self.autoencoder.decode(latent)
360
+
361
+ return waveform
362
+
363
+
364
+ class CrossAttentionAudioDiffusion(
365
+ SingleTaskCrossAttentionAudioDiffusion, DurationAdapterMixin
366
+ ):
367
+ def __init__(
368
+ self,
369
+ autoencoder: AutoEncoderBase,
370
+ content_encoder: ContentEncoder,
371
+ content_adapter: ContentAdapterBase,
372
+ backbone: nn.Module,
373
+ content_dim: int = None,
374
+ frame_resolution: float = None,
375
+ duration_offset: float = 1.0,
376
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
377
+ snr_gamma: float = None,
378
+ cfg_drop_ratio: float = 0.2,
379
+ ):
380
+ super().__init__(
381
+ autoencoder=autoencoder,
382
+ content_encoder=content_encoder,
383
+ backbone=backbone,
384
+ noise_scheduler_name=noise_scheduler_name,
385
+ snr_gamma=snr_gamma,
386
+ cfg_drop_ratio=cfg_drop_ratio
387
+ )
388
+ ContentEncoderAdapterMixin.__init__(
389
+ self,
390
+ content_encoder=content_encoder,
391
+ content_adapter=content_adapter,
392
+ )
393
+ DurationAdapterMixin.__init__(
394
+ self,
395
+ latent_token_rate=autoencoder.latent_token_rate,
396
+ offset=duration_offset,
397
+ )
398
+
399
+ def encode_content_with_instruction(
400
+ self,
401
+ content: list[Any],
402
+ task: list[str],
403
+ device: str | torch.device,
404
+ instruction: torch.Tensor,
405
+ instruction_lengths: torch.Tensor,
406
+ ):
407
+ content_dict = self.encode_content(
408
+ content, task, device, instruction, instruction_lengths
409
+ )
410
+ return (
411
+ content_dict["content"],
412
+ content_dict["content_mask"],
413
+ content_dict["global_duration_pred"],
414
+ content_dict["local_duration_pred"],
415
+ content_dict["length_aligned_content"],
416
+ )
417
+
418
+ def forward(
419
+ self,
420
+ content: list[Any],
421
+ task: list[str],
422
+ waveform: torch.Tensor,
423
+ waveform_lengths: torch.Tensor,
424
+ instruction: torch.Tensor,
425
+ instruction_lengths: Sequence[int],
426
+ loss_reduce: bool = True,
427
+ **kwargs
428
+ ):
429
+ device = self.dummy_param.device
430
+ loss_reduce = self.training or (loss_reduce and not self.training)
431
+
432
+ self.autoencoder.eval()
433
+ with torch.no_grad():
434
+ latent, latent_mask = self.autoencoder.encode(
435
+ waveform.unsqueeze(1), waveform_lengths
436
+ )
437
+
438
+ content, content_mask, global_duration_pred, _, _ = \
439
+ self.encode_content_with_instruction(
440
+ content, task, device, instruction, instruction_lengths
441
+ )
442
+ global_duration_loss = self.get_global_duration_loss(
443
+ global_duration_pred, latent_mask, reduce=loss_reduce
444
+ )
445
+
446
+ if self.training and self.classifier_free_guidance:
447
+ mask_indices = [
448
+ k for k in range(len(waveform))
449
+ if random.random() < self.cfg_drop_ratio
450
+ ]
451
+ if len(mask_indices) > 0:
452
+ content[mask_indices] = 0
453
+
454
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
455
+ latent, training=self.training
456
+ )
457
+
458
+ pred: torch.Tensor = self.backbone(
459
+ x=noisy_latent,
460
+ timesteps=timesteps,
461
+ context=content,
462
+ x_mask=latent_mask,
463
+ context_mask=content_mask
464
+ )
465
+
466
+ pred = pred.transpose(1, self.autoencoder.time_dim)
467
+ target = target.transpose(1, self.autoencoder.time_dim)
468
+ diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask)
469
+
470
+ return {
471
+ "diff_loss": diff_loss,
472
+ "global_duration_loss": global_duration_loss,
473
+ }
474
+
475
+ @torch.no_grad()
476
+ def inference(
477
+ self,
478
+ content: list[Any],
479
+ condition: list[Any],
480
+ task: list[str],
481
+ is_time_aligned: Sequence[bool],
482
+ instruction: torch.Tensor,
483
+ instruction_lengths: Sequence[int],
484
+ scheduler: SchedulerMixin,
485
+ num_steps: int = 50,
486
+ guidance_scale: float = 3.0,
487
+ guidance_rescale: float = 0.0,
488
+ disable_progress: bool = True,
489
+ use_gt_duration: bool = False,
490
+ **kwargs
491
+ ):
492
+ device = self.dummy_param.device
493
+ classifier_free_guidance = guidance_scale > 1.0
494
+
495
+ (
496
+ content,
497
+ content_mask,
498
+ global_duration_pred,
499
+ local_duration_pred,
500
+ _,
501
+ ) = self.encode_content_with_instruction(
502
+ content, task, device, instruction, instruction_lengths
503
+ )
504
+
505
+ if use_gt_duration:
506
+ raise NotImplementedError(
507
+ "Using ground truth global duration only is not implemented yet"
508
+ )
509
+
510
+ # prepare global duration
511
+ global_duration = self.prepare_global_duration(
512
+ global_duration_pred,
513
+ local_duration_pred,
514
+ is_time_aligned,
515
+ use_local=False
516
+ )
517
+ latent_length = torch.round(global_duration * self.latent_token_rate)
518
+ latent_mask = create_mask_from_length(latent_length).to(device)
519
+ max_latent_length = latent_mask.sum(1).max().item()
520
+
521
+ # prepare latent and noise
522
+ if classifier_free_guidance:
523
+ uncond_content = torch.zeros_like(content)
524
+ uncond_content_mask = content_mask.detach().clone()
525
+ context = torch.cat([uncond_content, content])
526
+ context_mask = torch.cat([uncond_content_mask, content_mask])
527
+ else:
528
+ context = content
529
+ context_mask = content_mask
530
+
531
+ batch_size = content.size(0)
532
+ latent_shape = tuple(
533
+ max_latent_length if dim is None else dim
534
+ for dim in self.autoencoder.latent_shape
535
+ )
536
+ latent = self.prepare_latent(
537
+ batch_size, scheduler, latent_shape, content.dtype, device
538
+ )
539
+
540
+ scheduler.set_timesteps(num_steps, device=device)
541
+ latent = self.iterative_denoise(
542
+ latent=latent,
543
+ scheduler=scheduler,
544
+ verbose=not disable_progress,
545
+ cfg=classifier_free_guidance,
546
+ cfg_scale=guidance_scale,
547
+ cfg_rescale=guidance_rescale,
548
+ backbone_input={
549
+ "x_mask": latent_mask,
550
+ "context": context,
551
+ "context_mask": context_mask,
552
+ }
553
+ )
554
+
555
+ waveform = self.autoencoder.decode(latent)
556
+
557
+ return waveform
558
+
559
+
560
+ class DummyContentAudioDiffusion(CrossAttentionAudioDiffusion):
561
+ def __init__(
562
+ self,
563
+ autoencoder: AutoEncoderBase,
564
+ content_encoder: ContentEncoder,
565
+ content_adapter: ContentAdapterBase,
566
+ backbone: nn.Module,
567
+ content_dim: int,
568
+ frame_resolution: float,
569
+ duration_offset: float = 1.0,
570
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
571
+ snr_gamma: float = None,
572
+ cfg_drop_ratio: float = 0.2,
573
+ ):
574
+ """
575
+ Args:
576
+ autoencoder:
577
+ Pretrained audio autoencoder that encodes raw waveforms into latent
578
+ space and decodes latents back to waveforms.
579
+ content_encoder:
580
+ Module that produces content embeddings (e.g., from text, MIDI, or
581
+ other modalities) used to guide the diffusion.
582
+ content_adapter (ContentAdapterBase):
583
+ Adapter module that fuses task instruction embeddings and content embeddings,
584
+ and performs duration prediction for time-aligned tasks.
585
+ backbone:
586
+ U‑Net or Transformer backbone that performs the core denoising
587
+ operations in latent space.
588
+ content_dim:
589
+ Dimension of the content embeddings produced by the `content_encoder`
590
+ and `content_adapter`.
591
+ frame_resolution:
592
+ Time resolution, in seconds, of each content frame when predicting
593
+ duration alignment. Used when calculating duration loss.
594
+ duration_offset:
595
+ A small positive offset (frame number) added to predicted durations
596
+ to ensure numerical stability of log-scaled duration prediction.
597
+ noise_scheduler_name:
598
+ Identifier of the pretrained noise scheduler to use.
599
+ snr_gamma:
600
+ Clipping value in the min-SNR diffusion loss weighting strategy.
601
+ cfg_drop_ratio:
602
+ Probability of dropping the content conditioning during training
603
+ to support CFG.
604
+ """
605
+ super().__init__(
606
+ autoencoder=autoencoder,
607
+ content_encoder=content_encoder,
608
+ content_adapter=content_adapter,
609
+ backbone=backbone,
610
+ duration_offset=duration_offset,
611
+ noise_scheduler_name=noise_scheduler_name,
612
+ snr_gamma=snr_gamma,
613
+ cfg_drop_ratio=cfg_drop_ratio,
614
+ )
615
+ self.frame_resolution = frame_resolution
616
+ self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim))
617
+ self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim))
618
+
619
+ def get_backbone_input(
620
+ self,
621
+ target_length: int,
622
+ content: torch.Tensor,
623
+ content_mask: torch.Tensor,
624
+ time_aligned_content: torch.Tensor,
625
+ length_aligned_content: torch.Tensor,
626
+ is_time_aligned: torch.Tensor,
627
+ ):
628
+ # TODO compatibility with 2D spectrogram VAE
629
+ time_aligned_content = trim_or_pad_length(
630
+ time_aligned_content, target_length, 1
631
+ )
632
+ length_aligned_content = trim_or_pad_length(
633
+ length_aligned_content, target_length, 1
634
+ )
635
+ # time_aligned_content: from monotonic aligned input, without frame expansion (phoneme)
636
+ # length_aligned_content: from aligned input (f0/energy)
637
+ time_aligned_content = time_aligned_content + length_aligned_content
638
+ time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
639
+ time_aligned_content.dtype
640
+ )
641
+
642
+ context = content
643
+ context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
644
+ # only use the first dummy non time aligned embedding
645
+ context_mask = content_mask.detach().clone()
646
+ context_mask[is_time_aligned, 1:] = False
647
+
648
+ # truncate dummy non time aligned context
649
+ if is_time_aligned.sum().item() < content.size(0):
650
+ trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
651
+ else:
652
+ trunc_nta_length = content.size(1)
653
+ context = context[:, :trunc_nta_length]
654
+ context_mask = context_mask[:, :trunc_nta_length]
655
+
656
+ return context, context_mask, time_aligned_content
657
+
658
+ def forward(
659
+ self,
660
+ content: list[Any],
661
+ task: list[str],
662
+ is_time_aligned: Sequence[bool],
663
+ duration: Sequence[float],
664
+ waveform: torch.Tensor,
665
+ waveform_lengths: torch.Tensor,
666
+ instruction: torch.Tensor,
667
+ instruction_lengths: Sequence[int],
668
+ loss_reduce: bool = True,
669
+ **kwargs
670
+ ):
671
+ device = self.dummy_param.device
672
+ loss_reduce = self.training or (loss_reduce and not self.training)
673
+
674
+ self.autoencoder.eval()
675
+ with torch.no_grad():
676
+ latent, latent_mask = self.autoencoder.encode(
677
+ waveform.unsqueeze(1), waveform_lengths
678
+ )
679
+
680
+ (
681
+ content, content_mask, global_duration_pred, local_duration_pred,
682
+ length_aligned_content
683
+ ) = self.encode_content_with_instruction(
684
+ content, task, device, instruction, instruction_lengths
685
+ )
686
+
687
+ # truncate unused non time aligned duration prediction
688
+ if is_time_aligned.sum() > 0:
689
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
690
+ else:
691
+ trunc_ta_length = content.size(1)
692
+
693
+ # duration loss
694
+ local_duration_pred = local_duration_pred[:, :trunc_ta_length]
695
+ ta_content_mask = content_mask[:, :trunc_ta_length]
696
+ local_duration_loss = self.get_local_duration_loss(
697
+ duration,
698
+ local_duration_pred,
699
+ ta_content_mask,
700
+ is_time_aligned,
701
+ reduce=loss_reduce
702
+ )
703
+ global_duration_loss = self.get_global_duration_loss(
704
+ global_duration_pred, latent_mask, reduce=loss_reduce
705
+ )
706
+
707
+ # --------------------------------------------------------------------
708
+ # prepare latent and diffusion-related noise
709
+ # --------------------------------------------------------------------
710
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
711
+ latent, training=self.training
712
+ )
713
+
714
+ # --------------------------------------------------------------------
715
+ # duration adapter
716
+ # --------------------------------------------------------------------
717
+ if is_time_aligned.sum() == 0 and \
718
+ duration.size(1) < content_mask.size(1):
719
+ # for non-time-aligned tasks like TTA, `duration` is a dummy placeholder
720
+ duration = F.pad(
721
+ duration, (0, content_mask.size(1) - duration.size(1))
722
+ )
723
+ time_aligned_content, _ = self.expand_by_duration(
724
+ x=content[:, :trunc_ta_length],
725
+ content_mask=ta_content_mask,
726
+ local_duration=duration,
727
+ )
728
+
729
+ # --------------------------------------------------------------------
730
+ # prepare input to the backbone
731
+ # --------------------------------------------------------------------
732
+ # TODO compatibility with 2D spectrogram VAE
733
+ latent_length = noisy_latent.size(self.autoencoder.time_dim)
734
+ context, context_mask, time_aligned_content = self.get_backbone_input(
735
+ latent_length, content, content_mask, time_aligned_content,
736
+ length_aligned_content, is_time_aligned
737
+ )
738
+
739
+ # --------------------------------------------------------------------
740
+ # classifier free guidance
741
+ # --------------------------------------------------------------------
742
+ if self.training and self.classifier_free_guidance:
743
+ mask_indices = [
744
+ k for k in range(len(waveform))
745
+ if random.random() < self.cfg_drop_ratio
746
+ ]
747
+ if len(mask_indices) > 0:
748
+ context[mask_indices] = 0
749
+ time_aligned_content[mask_indices] = 0
750
+
751
+ pred: torch.Tensor = self.backbone(
752
+ x=noisy_latent,
753
+ x_mask=latent_mask,
754
+ timesteps=timesteps,
755
+ context=context,
756
+ context_mask=context_mask,
757
+ time_aligned_context=time_aligned_content,
758
+ )
759
+ pred = pred.transpose(1, self.autoencoder.time_dim)
760
+ target = target.transpose(1, self.autoencoder.time_dim)
761
+ diff_loss = self.loss_with_snr(
762
+ pred, target, timesteps, latent_mask, reduce=loss_reduce
763
+ )
764
+ return {
765
+ "diff_loss": diff_loss,
766
+ "local_duration_loss": local_duration_loss,
767
+ "global_duration_loss": global_duration_loss
768
+ }
769
+
770
+ @torch.no_grad()
771
+ def inference(
772
+ self,
773
+ content: list[Any],
774
+ condition: list[Any],
775
+ task: list[str],
776
+ is_time_aligned: list[bool],
777
+ instruction: torch.Tensor,
778
+ instruction_lengths: Sequence[int],
779
+ scheduler: SchedulerMixin,
780
+ num_steps: int = 20,
781
+ guidance_scale: float = 3.0,
782
+ guidance_rescale: float = 0.0,
783
+ disable_progress: bool = True,
784
+ use_gt_duration: bool = False,
785
+ **kwargs
786
+ ):
787
+ device = self.dummy_param.device
788
+ classifier_free_guidance = guidance_scale > 1.0
789
+
790
+ (
791
+ content, content_mask, global_duration_pred, local_duration_pred,
792
+ length_aligned_content
793
+ ) = self.encode_content_with_instruction(
794
+ content, task, device, instruction, instruction_lengths
795
+ )
796
+
797
+ batch_size = content.size(0)
798
+
799
+ # truncate dummy time aligned duration prediction
800
+ is_time_aligned = torch.as_tensor(is_time_aligned)
801
+ if is_time_aligned.sum() > 0:
802
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
803
+ else:
804
+ trunc_ta_length = content.size(1)
805
+
806
+ # prepare local duration
807
+ local_duration = self.prepare_local_duration(
808
+ local_duration_pred, content_mask
809
+ )
810
+ local_duration = local_duration[:, :trunc_ta_length]
811
+ # use ground truth duration
812
+ if use_gt_duration and "duration" in kwargs:
813
+ local_duration = torch.as_tensor(kwargs["duration"]).to(device)
814
+
815
+ # prepare global duration
816
+ global_duration = self.prepare_global_duration(
817
+ global_duration_pred, local_duration, is_time_aligned
818
+ )
819
+
820
+ # --------------------------------------------------------------------
821
+ # duration adapter
822
+ # --------------------------------------------------------------------
823
+ time_aligned_content, latent_mask = self.expand_by_duration(
824
+ x=content[:, :trunc_ta_length],
825
+ content_mask=content_mask[:, :trunc_ta_length],
826
+ local_duration=local_duration,
827
+ global_duration=global_duration,
828
+ )
829
+
830
+ context, context_mask, time_aligned_content = self.get_backbone_input(
831
+ target_length=time_aligned_content.size(1),
832
+ content=content,
833
+ content_mask=content_mask,
834
+ time_aligned_content=time_aligned_content,
835
+ length_aligned_content=length_aligned_content,
836
+ is_time_aligned=is_time_aligned
837
+ )
838
+
839
+ # --------------------------------------------------------------------
840
+ # prepare unconditional input
841
+ # --------------------------------------------------------------------
842
+ if classifier_free_guidance:
843
+ uncond_time_aligned_content = torch.zeros_like(
844
+ time_aligned_content
845
+ )
846
+ uncond_context = torch.zeros_like(context)
847
+ uncond_context_mask = context_mask.detach().clone()
848
+ time_aligned_content = torch.cat([
849
+ uncond_time_aligned_content, time_aligned_content
850
+ ])
851
+ context = torch.cat([uncond_context, context])
852
+ context_mask = torch.cat([uncond_context_mask, context_mask])
853
+ latent_mask = torch.cat([
854
+ latent_mask, latent_mask.detach().clone()
855
+ ])
856
+
857
+ # --------------------------------------------------------------------
858
+ # prepare input to the backbone
859
+ # --------------------------------------------------------------------
860
+ latent_length = latent_mask.sum(1).max().item()
861
+ latent_shape = tuple(
862
+ latent_length if dim is None else dim
863
+ for dim in self.autoencoder.latent_shape
864
+ )
865
+ latent = self.prepare_latent(
866
+ batch_size, scheduler, latent_shape, content.dtype, device
867
+ )
868
+
869
+ scheduler.set_timesteps(num_steps, device=device)
870
+ latent = self.iterative_denoise(
871
+ latent=latent,
872
+ scheduler=scheduler,
873
+ verbose=not disable_progress,
874
+ cfg=classifier_free_guidance,
875
+ cfg_scale=guidance_scale,
876
+ cfg_rescale=guidance_rescale,
877
+ backbone_input={
878
+ "x_mask": latent_mask,
879
+ "context": context,
880
+ "context_mask": context_mask,
881
+ "time_aligned_context": time_aligned_content,
882
+ }
883
+ )
884
+ # TODO variable length decoding, using `latent_mask`
885
+ waveform = self.autoencoder.decode(latent)
886
+ return waveform
887
+
888
+
889
+ class DoubleContentAudioDiffusion(DummyContentAudioDiffusion):
890
+ def get_backbone_input(
891
+ self,
892
+ target_length: int,
893
+ content: torch.Tensor,
894
+ content_mask: torch.Tensor,
895
+ time_aligned_content: torch.Tensor,
896
+ length_aligned_content: torch.Tensor,
897
+ is_time_aligned: torch.Tensor,
898
+ ):
899
+ time_aligned_content = trim_or_pad_length(
900
+ time_aligned_content, target_length, 1
901
+ )
902
+ context_length = min(content.size(1), time_aligned_content.size(1))
903
+ time_aligned_content[~is_time_aligned, :context_length] = content[
904
+ ~is_time_aligned, :context_length]
905
+ length_aligned_content = trim_or_pad_length(
906
+ length_aligned_content, target_length, 1
907
+ )
908
+ time_aligned_content = time_aligned_content + length_aligned_content
909
+
910
+ context = content
911
+ context_mask = content_mask.detach().clone()
912
+
913
+ return context, context_mask, time_aligned_content
914
+
915
+
916
+ class HybridContentAudioDiffusion(DummyContentAudioDiffusion):
917
+ def get_backbone_input(
918
+ self,
919
+ target_length: int,
920
+ content: torch.Tensor,
921
+ content_mask: torch.Tensor,
922
+ time_aligned_content: torch.Tensor,
923
+ length_aligned_content: torch.Tensor,
924
+ is_time_aligned: torch.Tensor,
925
+ ):
926
+ # TODO compatibility with 2D spectrogram VAE
927
+ time_aligned_content = trim_or_pad_length(
928
+ time_aligned_content, target_length, 1
929
+ )
930
+ length_aligned_content = trim_or_pad_length(
931
+ length_aligned_content, target_length, 1
932
+ )
933
+ # time_aligned_content: from monotonic aligned input, without frame expansion (phoneme)
934
+ # length_aligned_content: from aligned input (f0/energy)
935
+ time_aligned_content = time_aligned_content + length_aligned_content
936
+ time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
937
+ time_aligned_content.dtype
938
+ )
939
+
940
+ context = content
941
+ context_mask = content_mask.detach().clone()
942
+
943
+ return context, context_mask, time_aligned_content
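The min-SNR weighting in `loss_with_snr` above reduces to `min(snr, snr_gamma) / snr` per timestep. A standalone sketch of that weight computation, using the same default scheduler name as `DiffusionMixin`; this only illustrates the weighting, not a full training step.

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="scheduler"
)
alphas_cumprod = scheduler.alphas_cumprod
timesteps = torch.tensor([10, 500, 900])
snr = alphas_cumprod[timesteps] / (1.0 - alphas_cumprod[timesteps])  # (alpha / sigma) ** 2
snr_gamma = 5.0
weights = torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr
# low-noise (early) timesteps with very large SNR are clipped to snr_gamma / snr,
# while high-noise timesteps keep a weight close to 1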
models/dit/attention.py ADDED
@@ -0,0 +1,349 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint
5
+ import einops
6
+ from einops import rearrange, repeat
7
+ from inspect import isfunction
8
+ from .rotary import RotaryEmbedding
9
+ from .modules import RMSNorm
10
+
11
+ if hasattr(nn.functional, 'scaled_dot_product_attention'):
12
+ ATTENTION_MODE = 'flash'
13
+ else:
14
+ ATTENTION_MODE = 'math'
15
+ print(f'attention mode is {ATTENTION_MODE}')
16
+
17
+
18
+ def add_mask(sim, mask):
19
+ b, ndim = sim.shape[0], mask.ndim
20
+ if ndim == 3:
21
+ mask = rearrange(mask, "b n m -> b 1 n m")
22
+ if ndim == 2:
23
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
24
+ max_neg_value = -torch.finfo(sim.dtype).max
25
+ sim = sim.masked_fill(~mask, max_neg_value)
26
+ return sim
27
+
28
+
29
+ def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
30
+ def default(val, d):
31
+ return val if val is not None else (d() if isfunction(d) else d)
32
+
33
+ b, i, j, device = q_shape[0], q_shape[-2], k_shape[-2], device
34
+ q_mask = default(
35
+ q_mask, torch.ones((b, i), device=device, dtype=torch.bool)
36
+ )
37
+ k_mask = default(
38
+ k_mask, torch.ones((b, j), device=device, dtype=torch.bool)
39
+ )
40
+ attn_mask = rearrange(q_mask, 'b i -> b 1 i 1'
41
+ ) * rearrange(k_mask, 'b j -> b 1 1 j')
42
+ return attn_mask
43
+
44
+
45
+ class Attention(nn.Module):
46
+ def __init__(
47
+ self,
48
+ dim,
49
+ context_dim=None,
50
+ num_heads=8,
51
+ qkv_bias=False,
52
+ qk_scale=None,
53
+ qk_norm=None,
54
+ attn_drop=0.,
55
+ proj_drop=0.,
56
+ rope_mode='none'
57
+ ):
58
+ super().__init__()
59
+ self.num_heads = num_heads
60
+ head_dim = dim // num_heads
61
+ self.scale = qk_scale or head_dim**-0.5
62
+
63
+ if context_dim is None:
64
+ self.cross_attn = False
65
+ else:
66
+ self.cross_attn = True
67
+
68
+ context_dim = dim if context_dim is None else context_dim
69
+
70
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
71
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
72
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
73
+
74
+ if qk_norm is None:
75
+ self.norm_q = nn.Identity()
76
+ self.norm_k = nn.Identity()
77
+ elif qk_norm == 'layernorm':
78
+ self.norm_q = nn.LayerNorm(head_dim)
79
+ self.norm_k = nn.LayerNorm(head_dim)
80
+ elif qk_norm == 'rmsnorm':
81
+ self.norm_q = RMSNorm(head_dim)
82
+ self.norm_k = RMSNorm(head_dim)
83
+ else:
84
+ raise NotImplementedError
85
+
86
+ self.attn_drop_p = attn_drop
87
+ self.attn_drop = nn.Dropout(attn_drop)
88
+ self.proj = nn.Linear(dim, dim)
89
+ self.proj_drop = nn.Dropout(proj_drop)
90
+
91
+ if self.cross_attn:
92
+ assert rope_mode == 'none'
93
+ self.rope_mode = rope_mode
94
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
95
+ self.rotary = RotaryEmbedding(dim=head_dim)
96
+ elif self.rope_mode == 'dual':
97
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
98
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
99
+
100
+ def _rotary(self, q, k, extras):
101
+ if self.rope_mode == 'shared':
102
+ q, k = self.rotary(q=q, k=k)
103
+ elif self.rope_mode == 'x_only':
104
+ q_x, k_x = self.rotary(
105
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
106
+ )
107
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
108
+ q = torch.cat((q_c, q_x), dim=2)
109
+ k = torch.cat((k_c, k_x), dim=2)
110
+ elif self.rope_mode == 'dual':
111
+ q_x, k_x = self.rotary_x(
112
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
113
+ )
114
+ q_c, k_c = self.rotary_c(
115
+ q=q[:, :, :extras, :], k=k[:, :, :extras, :]
116
+ )
117
+ q = torch.cat((q_c, q_x), dim=2)
118
+ k = torch.cat((k_c, k_x), dim=2)
119
+ elif self.rope_mode == 'none':
120
+ pass
121
+ else:
122
+ raise NotImplementedError
123
+ return q, k
124
+
125
+ def _attn(self, q, k, v, mask_binary):
126
+ if ATTENTION_MODE == 'flash':
127
+ x = F.scaled_dot_product_attention(
128
+ q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary
129
+ )
130
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
131
+ elif ATTENTION_MODE == 'math':
132
+ attn = (q @ k.transpose(-2, -1)) * self.scale
133
+ attn = add_mask(
134
+ attn, mask_binary
135
+ ) if mask_binary is not None else attn
136
+ attn = attn.softmax(dim=-1)
137
+ attn = self.attn_drop(attn)
138
+ x = (attn @ v).transpose(1, 2)
139
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
140
+ else:
141
+ raise NotImplementedError
142
+ return x
143
+
144
+ def forward(self, x, context=None, context_mask=None, extras=0):
145
+ B, L, C = x.shape
146
+ if context is None:
147
+ context = x
148
+
149
+ q = self.to_q(x)
150
+ k = self.to_k(context)
151
+ v = self.to_v(context)
152
+
153
+ if context_mask is not None:
154
+ mask_binary = create_mask(
155
+ x.shape, context.shape, x.device, None, context_mask
156
+ )
157
+ else:
158
+ mask_binary = None
159
+
160
+ q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
161
+ k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
162
+ v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)
163
+
164
+ q = self.norm_q(q)
165
+ k = self.norm_k(k)
166
+
167
+ q, k = self._rotary(q, k, extras)
168
+
169
+ x = self._attn(q, k, v, mask_binary)
170
+
171
+ x = self.proj(x)
172
+ x = self.proj_drop(x)
173
+ return x
174
+
175
+
176
+ class JointAttention(nn.Module):
177
+ def __init__(
178
+ self,
179
+ dim,
180
+ num_heads=8,
181
+ qkv_bias=False,
182
+ qk_scale=None,
183
+ qk_norm=None,
184
+ attn_drop=0.,
185
+ proj_drop=0.,
186
+ rope_mode='none'
187
+ ):
188
+ super().__init__()
189
+ self.num_heads = num_heads
190
+ head_dim = dim // num_heads
191
+ self.scale = qk_scale or head_dim**-0.5
192
+
193
+ self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(
194
+ dim, qkv_bias
195
+ )
196
+ self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(
197
+ dim, qkv_bias
198
+ )
199
+
200
+ self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
201
+ self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)
202
+
203
+ self.attn_drop_p = attn_drop
204
+ self.attn_drop = nn.Dropout(attn_drop)
205
+
206
+ self.proj_x = nn.Linear(dim, dim)
207
+ self.proj_drop_x = nn.Dropout(proj_drop)
208
+
209
+ self.proj_c = nn.Linear(dim, dim)
210
+ self.proj_drop_c = nn.Dropout(proj_drop)
211
+
212
+ self.rope_mode = rope_mode
213
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
214
+ self.rotary = RotaryEmbedding(dim=head_dim)
215
+ elif self.rope_mode == 'dual':
216
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
217
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
218
+
219
+ def _make_qkv_layers(self, dim, qkv_bias):
220
+ return (
221
+ nn.Linear(dim, dim,
222
+ bias=qkv_bias), nn.Linear(dim, dim, bias=qkv_bias),
223
+ nn.Linear(dim, dim, bias=qkv_bias)
224
+ )
225
+
226
+ def _make_norm_layers(self, qk_norm, head_dim):
227
+ if qk_norm is None:
228
+ norm_q = nn.Identity()
229
+ norm_k = nn.Identity()
230
+ elif qk_norm == 'layernorm':
231
+ norm_q = nn.LayerNorm(head_dim)
232
+ norm_k = nn.LayerNorm(head_dim)
233
+ elif qk_norm == 'rmsnorm':
234
+ norm_q = RMSNorm(head_dim)
235
+ norm_k = RMSNorm(head_dim)
236
+ else:
237
+ raise NotImplementedError
238
+ return norm_q, norm_k
239
+
240
+ def _rotary(self, q, k, extras):
241
+ if self.rope_mode == 'shared':
242
+ q, k = self.rotary(q=q, k=k)
243
+ elif self.rope_mode == 'x_only':
244
+ q_x, k_x = self.rotary(
245
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
246
+ )
247
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
248
+ q = torch.cat((q_c, q_x), dim=2)
249
+ k = torch.cat((k_c, k_x), dim=2)
250
+ elif self.rope_mode == 'dual':
251
+ q_x, k_x = self.rotary_x(
252
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
253
+ )
254
+ q_c, k_c = self.rotary_c(
255
+ q=q[:, :, :extras, :], k=k[:, :, :extras, :]
256
+ )
257
+ q = torch.cat((q_c, q_x), dim=2)
258
+ k = torch.cat((k_c, k_x), dim=2)
259
+ elif self.rope_mode == 'none':
260
+ pass
261
+ else:
262
+ raise NotImplementedError
263
+ return q, k
264
+
265
+ def _attn(self, q, k, v, mask_binary):
266
+ if ATTENTION_MODE == 'flash':
267
+ x = F.scaled_dot_product_attention(
268
+ q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary
269
+ )
270
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
271
+ elif ATTENTION_MODE == 'math':
272
+ attn = (q @ k.transpose(-2, -1)) * self.scale
273
+ attn = add_mask(
274
+ attn, mask_binary
275
+ ) if mask_binary is not None else attn
276
+ attn = attn.softmax(dim=-1)
277
+ attn = self.attn_drop(attn)
278
+ x = (attn @ v).transpose(1, 2)
279
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
280
+ else:
281
+ raise NotImplementedError
282
+ return x
283
+
284
+ def _cat_mask(self, x, context, x_mask=None, context_mask=None):
285
+ B = x.shape[0]
286
+ if x_mask is None:
287
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
288
+ if context_mask is None:
289
+ context_mask = torch.ones(
290
+ B, context.shape[-2], device=context.device
291
+ ).bool()
292
+ mask = torch.cat([context_mask, x_mask], dim=1)
293
+ return mask
294
+
295
+ def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
296
+ B, Lx, C = x.shape
297
+ _, Lc, _ = context.shape
298
+ if x_mask is not None or context_mask is not None:
299
+ mask = self._cat_mask(
300
+ x, context, x_mask=x_mask, context_mask=context_mask
301
+ )
302
+ shape = [B, Lx + Lc, C]
303
+ mask_binary = create_mask(
304
+ q_shape=shape,
305
+ k_shape=shape,
306
+ device=x.device,
307
+ q_mask=None,
308
+ k_mask=mask
309
+ )
310
+ else:
311
+ mask_binary = None
312
+
313
+ qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
314
+ qc, kc, vc = self.to_qc(context), self.to_kc(context
315
+ ), self.to_vc(context)
316
+
317
+ qx, kx, vx = map(
318
+ lambda t: einops.
319
+ rearrange(t, 'B L (H D) -> B H L D', H=self.num_heads),
320
+ [qx, kx, vx]
321
+ )
322
+ qc, kc, vc = map(
323
+ lambda t: einops.
324
+ rearrange(t, 'B L (H D) -> B H L D', H=self.num_heads),
325
+ [qc, kc, vc]
326
+ )
327
+
328
+ qx, kx = self.norm_qx(qx), self.norm_kx(kx)
329
+ qc, kc = self.norm_qc(qc), self.norm_kc(kc)
330
+
331
+ q, k, v = (
332
+ torch.cat([qc, qx],
333
+ dim=2), torch.cat([kc, kx],
334
+ dim=2), torch.cat([vc, vx], dim=2)
335
+ )
336
+
337
+ q, k = self._rotary(q, k, extras)
338
+
339
+ x = self._attn(q, k, v, mask_binary)
340
+
341
+ context, x = x[:, :Lc, :], x[:, Lc:, :]
342
+
343
+ x = self.proj_x(x)
344
+ x = self.proj_drop_x(x)
345
+
346
+ context = self.proj_c(context)
347
+ context = self.proj_drop_c(context)
348
+
349
+ return x, context
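Note: a minimal usage sketch of the two attention modules above. It assumes this file is models/dit/attention.py (mask_dit.py below imports Attention from .attention); all shapes and hyper-parameters are illustrative, not values used by the repository.

    # Usage sketch; the import path is an assumption based on mask_dit.py's
    # "from .attention import Attention".
    import torch
    from models.dit.attention import Attention, JointAttention

    self_attn = Attention(dim=64, num_heads=8, rope_mode='none')
    x = torch.randn(2, 10, 64)              # (batch, sequence length, dim)
    y = self_attn(x)                        # context defaults to x -> self-attention
    print(y.shape)                          # torch.Size([2, 10, 64])

    joint = JointAttention(dim=64, num_heads=8, rope_mode='none')
    ctx = torch.randn(2, 5, 64)             # conditioning stream
    x_out, ctx_out = joint(x, ctx)          # both streams attend over their concatenation
    print(x_out.shape, ctx_out.shape)       # (2, 10, 64) and (2, 5, 64)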
models/dit/audio_diffsingernet_dit.py ADDED
@@ -0,0 +1,520 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from .mask_dit import DiTBlock, FinalBlock, UDiT
6
+ from .modules import (
7
+ film_modulate,
8
+ PatchEmbed,
9
+ PE_wrapper,
10
+ TimestepEmbedder,
11
+ RMSNorm,
12
+ )
13
+
14
+
15
+ class AudioDiTBlock(DiTBlock):
16
+ """
17
+ A modified DiT block in which time_aligned_context is added to the latent.
18
+ """
19
+ def __init__(
20
+ self,
21
+ dim,
22
+ time_aligned_context_dim,
23
+ dilation,
24
+ context_dim=None,
25
+ num_heads=8,
26
+ mlp_ratio=4.,
27
+ qkv_bias=False,
28
+ qk_scale=None,
29
+ qk_norm=None,
30
+ act_layer='gelu',
31
+ norm_layer=nn.LayerNorm,
32
+ time_fusion='none',
33
+ ada_sola_rank=None,
34
+ ada_sola_alpha=None,
35
+ skip=False,
36
+ skip_norm=False,
37
+ rope_mode='none',
38
+ context_norm=False,
39
+ use_checkpoint=False
40
+ ):
41
+ super().__init__(
42
+ dim=dim,
43
+ context_dim=context_dim,
44
+ num_heads=num_heads,
45
+ mlp_ratio=mlp_ratio,
46
+ qkv_bias=qkv_bias,
47
+ qk_scale=qk_scale,
48
+ qk_norm=qk_norm,
49
+ act_layer=act_layer,
50
+ norm_layer=norm_layer,
51
+ time_fusion=time_fusion,
52
+ ada_sola_rank=ada_sola_rank,
53
+ ada_sola_alpha=ada_sola_alpha,
54
+ skip=skip,
55
+ skip_norm=skip_norm,
56
+ rope_mode=rope_mode,
57
+ context_norm=context_norm,
58
+ use_checkpoint=use_checkpoint
59
+ )
60
+ # time-aligned context projection
61
+ self.ta_context_projection = nn.Linear(
62
+ time_aligned_context_dim, 2 * dim
63
+ )
64
+ self.dilated_conv = nn.Conv1d(
65
+ dim, 2 * dim, kernel_size=3, padding=dilation, dilation=dilation
66
+ )
67
+
68
+ def forward(
69
+ self,
70
+ x,
71
+ time_aligned_context,
72
+ time_token=None,
73
+ time_ada=None,
74
+ skip=None,
75
+ context=None,
76
+ x_mask=None,
77
+ context_mask=None,
78
+ extras=None
79
+ ):
80
+ if self.use_checkpoint:
81
+ return checkpoint(
82
+ self._forward,
83
+ x,
84
+ time_aligned_context,
85
+ time_token,
86
+ time_ada,
87
+ skip,
88
+ context,
89
+ x_mask,
90
+ context_mask,
91
+ extras,
92
+ use_reentrant=False
93
+ )
94
+ else:
95
+ return self._forward(
96
+ x,
97
+ time_aligned_context,
98
+ time_token,
99
+ time_ada,
100
+ skip,
101
+ context,
102
+ x_mask,
103
+ context_mask,
104
+ extras,
105
+ )
106
+
107
+ def _forward(
108
+ self,
109
+ x,
110
+ time_aligned_context,
111
+ time_token=None,
112
+ time_ada=None,
113
+ skip=None,
114
+ context=None,
115
+ x_mask=None,
116
+ context_mask=None,
117
+ extras=None
118
+ ):
119
+ B, T, C = x.shape
120
+ if self.skip_linear is not None:
121
+ assert skip is not None
122
+ cat = torch.cat([x, skip], dim=-1)
123
+ cat = self.skip_norm(cat)
124
+ x = self.skip_linear(cat)
125
+
126
+ if self.use_adanorm:
127
+ time_ada = self.adaln(time_token, time_ada)
128
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
129
+ gate_mlp) = time_ada.chunk(6, dim=1)
130
+
131
+ # self attention
132
+ if self.use_adanorm:
133
+ x_norm = film_modulate(
134
+ self.norm1(x), shift=shift_msa, scale=scale_msa
135
+ )
136
+ x = x + (1-gate_msa) * self.attn(
137
+ x_norm, context=None, context_mask=x_mask, extras=extras
138
+ )
139
+ else:
140
+ # TODO diffusion timestep input is not fused here
141
+ x = x + self.attn(
142
+ self.norm1(x),
143
+ context=None,
144
+ context_mask=x_mask,
145
+ extras=extras
146
+ )
147
+
148
+ # time-aligned context
149
+ time_aligned_context = self.ta_context_projection(time_aligned_context)
150
+ x = self.dilated_conv(x.transpose(1, 2)
151
+ ).transpose(1, 2) + time_aligned_context
152
+
153
+ gate, filter = torch.chunk(x, 2, dim=-1)
154
+ x = torch.sigmoid(gate) * torch.tanh(filter)
155
+
156
+ # cross attention
157
+ if self.use_context:
158
+ assert context is not None
159
+ x = x + self.cross_attn(
160
+ x=self.norm2(x),
161
+ context=self.norm_context(context),
162
+ context_mask=context_mask,
163
+ extras=extras
164
+ )
165
+
166
+ # mlp
167
+ if self.use_adanorm:
168
+ x_norm = film_modulate(
169
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
170
+ )
171
+ x = x + (1-gate_mlp) * self.mlp(x_norm)
172
+ else:
173
+ x = x + self.mlp(self.norm3(x))
174
+
175
+ return x
176
+
177
+
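Note: AudioDiTBlock above injects the time-aligned context through a dilated Conv1d followed by a WaveNet/DiffSinger-style gated activation (sigmoid gate times tanh filter). Below is a self-contained sketch of just that fusion step with illustrative sizes; proj and conv are hypothetical stand-ins for ta_context_projection and dilated_conv.

    import torch
    import torch.nn as nn

    dim, ta_dim, dilation = 8, 4, 2
    x = torch.randn(1, 16, dim)                  # (B, T, dim) latent after self-attention
    ta_ctx = torch.randn(1, 16, ta_dim)          # time-aligned context with the same length T

    proj = nn.Linear(ta_dim, 2 * dim)            # stand-in for ta_context_projection
    conv = nn.Conv1d(dim, 2 * dim, kernel_size=3,
                     padding=dilation, dilation=dilation)  # stand-in for dilated_conv

    h = conv(x.transpose(1, 2)).transpose(1, 2) + proj(ta_ctx)   # (B, T, 2*dim)
    gate, filt = torch.chunk(h, 2, dim=-1)
    out = torch.sigmoid(gate) * torch.tanh(filt) # gated activation; length T is preserved
    print(out.shape)                             # torch.Size([1, 16, 8])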
178
+ class AudioUDiT(UDiT):
179
+ def __init__(
180
+ self,
181
+ img_size=224,
182
+ patch_size=16,
183
+ in_chans=3,
184
+ input_type='2d',
185
+ out_chans=None,
186
+ embed_dim=768,
187
+ depth=12,
188
+ dilation_cycle_length=4,
189
+ num_heads=12,
190
+ mlp_ratio=4,
191
+ qkv_bias=False,
192
+ qk_scale=None,
193
+ qk_norm=None,
194
+ act_layer='gelu',
195
+ norm_layer='layernorm',
196
+ context_norm=False,
197
+ use_checkpoint=False,
198
+ time_fusion='token',
199
+ ada_sola_rank=None,
200
+ ada_sola_alpha=None,
201
+ cls_dim=None,
202
+ time_aligned_context_dim=768,
203
+ context_dim=768,
204
+ context_fusion='concat',
205
+ context_max_length=128,
206
+ context_pe_method='sinu',
207
+ pe_method='abs',
208
+ rope_mode='none',
209
+ use_conv=True,
210
+ skip=True,
211
+ skip_norm=True
212
+ ):
213
+ nn.Module.__init__(self)
214
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
215
+
216
+ # input
217
+ self.in_chans = in_chans
218
+ self.input_type = input_type
219
+ if self.input_type == '2d':
220
+ num_patches = (img_size[0] //
221
+ patch_size) * (img_size[1] // patch_size)
222
+ elif self.input_type == '1d':
223
+ num_patches = img_size // patch_size
224
+ self.patch_embed = PatchEmbed(
225
+ patch_size=patch_size,
226
+ in_chans=in_chans,
227
+ embed_dim=embed_dim,
228
+ input_type=input_type
229
+ )
230
+ out_chans = in_chans if out_chans is None else out_chans
231
+ self.out_chans = out_chans
232
+
233
+ # position embedding
234
+ self.rope = rope_mode
235
+ self.x_pe = PE_wrapper(
236
+ dim=embed_dim, method=pe_method, length=num_patches
237
+ )
238
+
239
+ # time embed
240
+ self.time_embed = TimestepEmbedder(embed_dim)
241
+ self.time_fusion = time_fusion
242
+ self.use_adanorm = False
243
+
244
+ # cls embed
245
+ if cls_dim is not None:
246
+ self.cls_embed = nn.Sequential(
247
+ nn.Linear(cls_dim, embed_dim, bias=True),
248
+ nn.SiLU(),
249
+ nn.Linear(embed_dim, embed_dim, bias=True),
250
+ )
251
+ else:
252
+ self.cls_embed = None
253
+
254
+ # time fusion
255
+ if time_fusion == 'token':
256
+ # put token at the beginning of sequence
257
+ self.extras = 2 if self.cls_embed else 1
258
+ self.time_pe = PE_wrapper(
259
+ dim=embed_dim, method='abs', length=self.extras
260
+ )
261
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
262
+ self.use_adanorm = True
263
+ # avoid repetitive SiLU for each AdaLN block
264
+ self.time_act = nn.SiLU()
265
+ self.extras = 0
266
+ self.time_ada_final = nn.Linear(
267
+ embed_dim, 2 * embed_dim, bias=True
268
+ )
269
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
270
+ # shared adaln
271
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
272
+ else:
273
+ self.time_ada = None
274
+ else:
275
+ raise NotImplementedError
276
+
277
+ # context
278
+ # use a simple projection
279
+ self.use_context = False
280
+ self.context_cross = False
281
+ self.context_max_length = context_max_length
282
+ self.context_fusion = 'none'
283
+ if context_dim is not None:
284
+ self.use_context = True
285
+ self.context_embed = nn.Sequential(
286
+ nn.Linear(context_dim, embed_dim, bias=True),
287
+ nn.SiLU(),
288
+ nn.Linear(embed_dim, embed_dim, bias=True),
289
+ )
290
+ self.context_fusion = context_fusion
291
+ if context_fusion == 'concat' or context_fusion == 'joint':
292
+ self.extras += context_max_length
293
+ self.context_pe = PE_wrapper(
294
+ dim=embed_dim,
295
+ method=context_pe_method,
296
+ length=context_max_length
297
+ )
298
+ # no cross attention layers
299
+ context_dim = None
300
+ elif context_fusion == 'cross':
301
+ self.context_pe = PE_wrapper(
302
+ dim=embed_dim,
303
+ method=context_pe_method,
304
+ length=context_max_length
305
+ )
306
+ self.context_cross = True
307
+ context_dim = embed_dim
308
+ else:
309
+ raise NotImplementedError
310
+
311
+ self.use_skip = skip
312
+
313
+ # norm layers
314
+ if norm_layer == 'layernorm':
315
+ norm_layer = nn.LayerNorm
316
+ elif norm_layer == 'rmsnorm':
317
+ norm_layer = RMSNorm
318
+ else:
319
+ raise NotImplementedError
320
+
321
+ self.in_blocks = nn.ModuleList([
322
+ AudioDiTBlock(
323
+ dim=embed_dim,
324
+ time_aligned_context_dim=time_aligned_context_dim,
325
+ dilation=2**(i % dilation_cycle_length),
326
+ context_dim=context_dim,
327
+ num_heads=num_heads,
328
+ mlp_ratio=mlp_ratio,
329
+ qkv_bias=qkv_bias,
330
+ qk_scale=qk_scale,
331
+ qk_norm=qk_norm,
332
+ act_layer=act_layer,
333
+ norm_layer=norm_layer,
334
+ time_fusion=time_fusion,
335
+ ada_sola_rank=ada_sola_rank,
336
+ ada_sola_alpha=ada_sola_alpha,
337
+ skip=False,
338
+ skip_norm=False,
339
+ rope_mode=self.rope,
340
+ context_norm=context_norm,
341
+ use_checkpoint=use_checkpoint
342
+ ) for i in range(depth // 2)
343
+ ])
344
+
345
+ self.mid_block = AudioDiTBlock(
346
+ dim=embed_dim,
347
+ time_aligned_context_dim=time_aligned_context_dim,
348
+ dilation=1,
349
+ context_dim=context_dim,
350
+ num_heads=num_heads,
351
+ mlp_ratio=mlp_ratio,
352
+ qkv_bias=qkv_bias,
353
+ qk_scale=qk_scale,
354
+ qk_norm=qk_norm,
355
+ act_layer=act_layer,
356
+ norm_layer=norm_layer,
357
+ time_fusion=time_fusion,
358
+ ada_sola_rank=ada_sola_rank,
359
+ ada_sola_alpha=ada_sola_alpha,
360
+ skip=False,
361
+ skip_norm=False,
362
+ rope_mode=self.rope,
363
+ context_norm=context_norm,
364
+ use_checkpoint=use_checkpoint
365
+ )
366
+
367
+ self.out_blocks = nn.ModuleList([
368
+ AudioDiTBlock(
369
+ dim=embed_dim,
370
+ time_aligned_context_dim=time_aligned_context_dim,
371
+ dilation=2**(i % dilation_cycle_length),
372
+ context_dim=context_dim,
373
+ num_heads=num_heads,
374
+ mlp_ratio=mlp_ratio,
375
+ qkv_bias=qkv_bias,
376
+ qk_scale=qk_scale,
377
+ qk_norm=qk_norm,
378
+ act_layer=act_layer,
379
+ norm_layer=norm_layer,
380
+ time_fusion=time_fusion,
381
+ ada_sola_rank=ada_sola_rank,
382
+ ada_sola_alpha=ada_sola_alpha,
383
+ skip=skip,
384
+ skip_norm=skip_norm,
385
+ rope_mode=self.rope,
386
+ context_norm=context_norm,
387
+ use_checkpoint=use_checkpoint
388
+ ) for i in range(depth // 2)
389
+ ])
390
+
391
+ # FinalLayer block
392
+ self.use_conv = use_conv
393
+ self.final_block = FinalBlock(
394
+ embed_dim=embed_dim,
395
+ patch_size=patch_size,
396
+ img_size=img_size,
397
+ in_chans=out_chans,
398
+ input_type=input_type,
399
+ norm_layer=norm_layer,
400
+ use_conv=use_conv,
401
+ use_adanorm=self.use_adanorm
402
+ )
403
+ self.initialize_weights()
404
+
405
+ def forward(
406
+ self,
407
+ x,
408
+ timesteps,
409
+ time_aligned_context,
410
+ context,
411
+ x_mask=None,
412
+ context_mask=None,
413
+ cls_token=None,
414
+ controlnet_skips=None,
415
+ ):
416
+ # make it compatible with int time step during inference
417
+ if timesteps.dim() == 0:
418
+ timesteps = timesteps.expand(x.shape[0]
419
+ ).to(x.device, dtype=torch.long)
420
+
421
+ x = self.patch_embed(x)
422
+ x = self.x_pe(x)
423
+
424
+ B, L, D = x.shape
425
+
426
+ if self.use_context:
427
+ context_token = self.context_embed(context)
428
+ context_token = self.context_pe(context_token)
429
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
430
+ x, x_mask = self._concat_x_context(
431
+ x=x,
432
+ context=context_token,
433
+ x_mask=x_mask,
434
+ context_mask=context_mask
435
+ )
436
+ context_token, context_mask = None, None
437
+ else:
438
+ context_token, context_mask = None, None
439
+
440
+ time_token = self.time_embed(timesteps)
441
+ if self.cls_embed:
442
+ cls_token = self.cls_embed(cls_token)
443
+ time_ada = None
444
+ time_ada_final = None
445
+ if self.use_adanorm:
446
+ if self.cls_embed:
447
+ time_token = time_token + cls_token
448
+ time_token = self.time_act(time_token)
449
+ time_ada_final = self.time_ada_final(time_token)
450
+ if self.time_ada is not None:
451
+ time_ada = self.time_ada(time_token)
452
+ else:
453
+ time_token = time_token.unsqueeze(dim=1)
454
+ if self.cls_embed:
455
+ cls_token = cls_token.unsqueeze(dim=1)
456
+ time_token = torch.cat([time_token, cls_token], dim=1)
457
+ time_token = self.time_pe(time_token)
458
+ x = torch.cat((time_token, x), dim=1)
459
+ if x_mask is not None:
460
+ x_mask = torch.cat([
461
+ torch.ones(B, time_token.shape[1],
462
+ device=x_mask.device).bool(), x_mask
463
+ ],
464
+ dim=1)
465
+ time_token = None
466
+
467
+ skips = []
468
+ for blk in self.in_blocks:
469
+ x = blk(
470
+ x=x,
471
+ time_aligned_context=time_aligned_context,
472
+ time_token=time_token,
473
+ time_ada=time_ada,
474
+ skip=None,
475
+ context=context_token,
476
+ x_mask=x_mask,
477
+ context_mask=context_mask,
478
+ extras=self.extras
479
+ )
480
+ if self.use_skip:
481
+ skips.append(x)
482
+
483
+ x = self.mid_block(
484
+ x=x,
485
+ time_aligned_context=time_aligned_context,
486
+ time_token=time_token,
487
+ time_ada=time_ada,
488
+ skip=None,
489
+ context=context_token,
490
+ x_mask=x_mask,
491
+ context_mask=context_mask,
492
+ extras=self.extras
493
+ )
494
+ for blk in self.out_blocks:
495
+ if self.use_skip:
496
+ skip = skips.pop()
497
+ if controlnet_skips:
498
+ # add to skip like u-net controlnet
499
+ skip = skip + controlnet_skips.pop()
500
+ else:
501
+ skip = None
502
+ if controlnet_skips:
503
+ # directly add to x
504
+ x = x + controlnet_skips.pop()
505
+
506
+ x = blk(
507
+ x=x,
508
+ time_aligned_context=time_aligned_context,
509
+ time_token=time_token,
510
+ time_ada=time_ada,
511
+ skip=skip,
512
+ context=context_token,
513
+ x_mask=x_mask,
514
+ context_mask=context_mask,
515
+ extras=self.extras
516
+ )
517
+
518
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
519
+
520
+ return x
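Note: in the 'token' time-fusion branch of the forward pass above, the timestep embedding (and an optional cls token) is prepended to the sequence and the padding mask is extended to cover it; downstream blocks then receive extras so the prepended tokens can be handled separately. A pure-tensor sketch of that bookkeeping with made-up shapes:

    import torch

    B, L, D = 2, 16, 8
    x = torch.randn(B, L, D)                     # patchified latent
    x_mask = torch.ones(B, L).bool()
    time_token = torch.randn(B, D).unsqueeze(1)  # (B, 1, D) from the timestep embedder

    x = torch.cat((time_token, x), dim=1)        # (B, extras + L, D), here extras = 1
    x_mask = torch.cat(
        [torch.ones(B, time_token.shape[1]).bool(), x_mask], dim=1
    )
    print(x.shape, x_mask.shape)                 # torch.Size([2, 17, 8]) torch.Size([2, 17])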
models/dit/audio_dit.py ADDED
@@ -0,0 +1,652 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from .mask_dit import DiTBlock, FinalBlock, UDiT
6
+ from .modules import (
7
+ film_modulate,
8
+ PatchEmbed,
9
+ PE_wrapper,
10
+ TimestepEmbedder,
11
+ RMSNorm,
12
+ )
13
+
14
+
15
+ class LayerFusionDiTBlock(DiTBlock):
16
+ """
17
+ A modified DiT block in which the time-aligned context is added to the latent.
18
+ """
19
+ def __init__(
20
+ self,
21
+ dim,
22
+ ta_context_dim,
23
+ ta_context_norm=False,
24
+ context_dim=None,
25
+ num_heads=8,
26
+ mlp_ratio=4.,
27
+ qkv_bias=False,
28
+ qk_scale=None,
29
+ qk_norm=None,
30
+ act_layer='gelu',
31
+ norm_layer=nn.LayerNorm,
32
+ ta_context_fusion='add',
33
+ time_fusion='none',
34
+ ada_sola_rank=None,
35
+ ada_sola_alpha=None,
36
+ skip=False,
37
+ skip_norm=False,
38
+ rope_mode='none',
39
+ context_norm=False,
40
+ use_checkpoint=False
41
+ ):
42
+ super().__init__(
43
+ dim=dim,
44
+ context_dim=context_dim,
45
+ num_heads=num_heads,
46
+ mlp_ratio=mlp_ratio,
47
+ qkv_bias=qkv_bias,
48
+ qk_scale=qk_scale,
49
+ qk_norm=qk_norm,
50
+ act_layer=act_layer,
51
+ norm_layer=norm_layer,
52
+ time_fusion=time_fusion,
53
+ ada_sola_rank=ada_sola_rank,
54
+ ada_sola_alpha=ada_sola_alpha,
55
+ skip=skip,
56
+ skip_norm=skip_norm,
57
+ rope_mode=rope_mode,
58
+ context_norm=context_norm,
59
+ use_checkpoint=use_checkpoint
60
+ )
61
+ self.ta_context_fusion = ta_context_fusion
62
+ self.ta_context_norm = ta_context_norm
63
+ if self.ta_context_fusion == "add":
64
+ self.ta_context_projection = nn.Linear(
65
+ ta_context_dim, dim, bias=False
66
+ )
67
+ self.ta_context_norm = norm_layer(
68
+ ta_context_dim
69
+ ) if self.ta_context_norm else nn.Identity()
70
+ elif self.ta_context_fusion == "concat":
71
+ self.ta_context_projection = nn.Linear(ta_context_dim + dim, dim)
72
+ self.ta_context_norm = norm_layer(
73
+ ta_context_dim + dim
74
+ ) if self.ta_context_norm else nn.Identity()
75
+
76
+ def forward(
77
+ self,
78
+ x,
79
+ time_aligned_context,
80
+ time_token=None,
81
+ time_ada=None,
82
+ skip=None,
83
+ context=None,
84
+ x_mask=None,
85
+ context_mask=None,
86
+ extras=None
87
+ ):
88
+ if self.use_checkpoint:
89
+ return checkpoint(
90
+ self._forward,
91
+ x,
92
+ time_aligned_context,
93
+ time_token,
94
+ time_ada,
95
+ skip,
96
+ context,
97
+ x_mask,
98
+ context_mask,
99
+ extras,
100
+ use_reentrant=False
101
+ )
102
+ else:
103
+ return self._forward(
104
+ x,
105
+ time_aligned_context,
106
+ time_token,
107
+ time_ada,
108
+ skip,
109
+ context,
110
+ x_mask,
111
+ context_mask,
112
+ extras,
113
+ )
114
+
115
+ def _forward(
116
+ self,
117
+ x,
118
+ time_aligned_context,
119
+ time_token=None,
120
+ time_ada=None,
121
+ skip=None,
122
+ context=None,
123
+ x_mask=None,
124
+ context_mask=None,
125
+ extras=None
126
+ ):
127
+ B, T, C = x.shape
128
+
129
+ # skip connection
130
+ if self.skip_linear is not None:
131
+ assert skip is not None
132
+ cat = torch.cat([x, skip], dim=-1)
133
+ cat = self.skip_norm(cat)
134
+ x = self.skip_linear(cat)
135
+
136
+ if self.use_adanorm:
137
+ time_ada = self.adaln(time_token, time_ada)
138
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
139
+ gate_mlp) = time_ada.chunk(6, dim=1)
140
+
141
+ # self attention
142
+ if self.use_adanorm:
143
+ x_norm = film_modulate(
144
+ self.norm1(x), shift=shift_msa, scale=scale_msa
145
+ )
146
+ tanh_gate_msa = torch.tanh(1 - gate_msa)
147
+ x = x + tanh_gate_msa * self.attn(
148
+ x_norm, context=None, context_mask=x_mask, extras=extras
149
+ )
150
+ # x = x + (1 - gate_msa) * self.attn(
151
+ # x_norm, context=None, context_mask=x_mask, extras=extras
152
+ # )
153
+ else:
154
+ # TODO diffusion timestep input is not fused here
155
+ x = x + self.attn(
156
+ self.norm1(x),
157
+ context=None,
158
+ context_mask=x_mask,
159
+ extras=extras
160
+ )
161
+
162
+ # time aligned context fusion
163
+ if self.ta_context_fusion == "add":
164
+ time_aligned_context = self.ta_context_projection(
165
+ self.ta_context_norm(time_aligned_context)
166
+ )
167
+ if time_aligned_context.size(1) < x.size(1):
168
+ time_aligned_context = nn.functional.pad(
169
+ time_aligned_context, (0, 0, 1, 0)
170
+ )
171
+ x = x + time_aligned_context
172
+ elif self.ta_context_fusion == "concat":
173
+ if time_aligned_context.size(1) < x.size(1):
174
+ time_aligned_context = nn.functional.pad(
175
+ time_aligned_context, (0, 0, 1, 0)
176
+ )
177
+ cat = torch.cat([x, time_aligned_context], dim=-1)
178
+ cat = self.ta_context_norm(cat)
179
+ x = self.ta_context_projection(cat)
180
+
181
+ # cross attention
182
+ if self.use_context:
183
+ assert context is not None
184
+ x = x + self.cross_attn(
185
+ x=self.norm2(x),
186
+ context=self.norm_context(context),
187
+ context_mask=context_mask,
188
+ extras=extras
189
+ )
190
+
191
+ # mlp
192
+ if self.use_adanorm:
193
+ x_norm = film_modulate(
194
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
195
+ )
196
+ x = x + (1 - gate_mlp) * self.mlp(x_norm)
197
+ else:
198
+ x = x + self.mlp(self.norm3(x))
199
+
200
+ return x
201
+
202
+
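Note: LayerFusionDiTBlock above supports two ta_context_fusion modes, 'add' (project the context to dim and add it) and 'concat' (concatenate along channels and project back), padding the context by one frame at the front when it is one step shorter than the latent (e.g. because of a prepended time token). A pure-tensor sketch with illustrative shapes; add_proj and cat_proj are hypothetical stand-ins for ta_context_projection in the two modes.

    import torch
    import torch.nn as nn

    dim, ta_dim = 8, 4
    x = torch.randn(1, 17, dim)                  # latent with one extra (time) token
    ta = torch.randn(1, 16, ta_dim)              # time-aligned context, one frame shorter

    ta_pad = nn.functional.pad(ta, (0, 0, 1, 0)) # pad one frame at the front -> (1, 17, ta_dim)

    add_proj = nn.Linear(ta_dim, dim, bias=False)        # 'add' mode
    x_add = x + add_proj(ta_pad)

    cat_proj = nn.Linear(ta_dim + dim, dim)              # 'concat' mode
    x_cat = cat_proj(torch.cat([x, ta_pad], dim=-1))
    print(x_add.shape, x_cat.shape)              # both torch.Size([1, 17, 8])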
203
+ class LayerFusionAudioDiT(UDiT):
204
+ def __init__(
205
+ self,
206
+ img_size=224,
207
+ patch_size=16,
208
+ in_chans=3,
209
+ input_type='2d',
210
+ out_chans=None,
211
+ embed_dim=768,
212
+ depth=12,
213
+ num_heads=12,
214
+ mlp_ratio=4,
215
+ qkv_bias=False,
216
+ qk_scale=None,
217
+ qk_norm=None,
218
+ act_layer='gelu',
219
+ norm_layer='layernorm',
220
+ context_norm=False,
221
+ use_checkpoint=False,
222
+ time_fusion='token',
223
+ ada_sola_rank=None,
224
+ ada_sola_alpha=None,
225
+ cls_dim=None,
226
+ ta_context_dim=768,
227
+ ta_context_fusion='concat',
228
+ ta_context_norm=True,
229
+ context_dim=768,
230
+ context_fusion='concat',
231
+ context_max_length=128,
232
+ context_pe_method='sinu',
233
+ pe_method='abs',
234
+ rope_mode='none',
235
+ use_conv=True,
236
+ skip=True,
237
+ skip_norm=True
238
+ ):
239
+ nn.Module.__init__(self)
240
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
241
+
242
+ # input
243
+ self.in_chans = in_chans
244
+ self.input_type = input_type
245
+ if self.input_type == '2d':
246
+ num_patches = (img_size[0] //
247
+ patch_size) * (img_size[1] // patch_size)
248
+ elif self.input_type == '1d':
249
+ num_patches = img_size // patch_size
250
+ self.patch_embed = PatchEmbed(
251
+ patch_size=patch_size,
252
+ in_chans=in_chans,
253
+ embed_dim=embed_dim,
254
+ input_type=input_type
255
+ )
256
+ out_chans = in_chans if out_chans is None else out_chans
257
+ self.out_chans = out_chans
258
+
259
+ # position embedding
260
+ self.rope = rope_mode
261
+ self.x_pe = PE_wrapper(
262
+ dim=embed_dim, method=pe_method, length=num_patches
263
+ )
264
+
265
+ # time embed
266
+ self.time_embed = TimestepEmbedder(embed_dim)
267
+ self.time_fusion = time_fusion
268
+ self.use_adanorm = False
269
+
270
+ # cls embed
271
+ if cls_dim is not None:
272
+ self.cls_embed = nn.Sequential(
273
+ nn.Linear(cls_dim, embed_dim, bias=True),
274
+ nn.SiLU(),
275
+ nn.Linear(embed_dim, embed_dim, bias=True),
276
+ )
277
+ else:
278
+ self.cls_embed = None
279
+
280
+ # time fusion
281
+ if time_fusion == 'token':
282
+ # put token at the beginning of sequence
283
+ self.extras = 2 if self.cls_embed else 1
284
+ self.time_pe = PE_wrapper(
285
+ dim=embed_dim, method='abs', length=self.extras
286
+ )
287
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
288
+ self.use_adanorm = True
289
+ # avoid repetitive SiLU for each AdaLN block
290
+ self.time_act = nn.SiLU()
291
+ self.extras = 0
292
+ self.time_ada_final = nn.Linear(
293
+ embed_dim, 2 * embed_dim, bias=True
294
+ )
295
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
296
+ # shared adaln
297
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
298
+ else:
299
+ self.time_ada = None
300
+ else:
301
+ raise NotImplementedError
302
+
303
+ # context
304
+ # use a simple projection
305
+ self.use_context = False
306
+ self.context_cross = False
307
+ self.context_max_length = context_max_length
308
+ self.context_fusion = 'none'
309
+ if context_dim is not None:
310
+ self.use_context = True
311
+ self.context_embed = nn.Sequential(
312
+ nn.Linear(context_dim, embed_dim, bias=True),
313
+ nn.SiLU(),
314
+ nn.Linear(embed_dim, embed_dim, bias=True),
315
+ )
316
+ self.context_fusion = context_fusion
317
+ if context_fusion == 'concat' or context_fusion == 'joint':
318
+ self.extras += context_max_length
319
+ self.context_pe = PE_wrapper(
320
+ dim=embed_dim,
321
+ method=context_pe_method,
322
+ length=context_max_length
323
+ )
324
+ # no cross attention layers
325
+ context_dim = None
326
+ elif context_fusion == 'cross':
327
+ self.context_pe = PE_wrapper(
328
+ dim=embed_dim,
329
+ method=context_pe_method,
330
+ length=context_max_length
331
+ )
332
+ self.context_cross = True
333
+ context_dim = embed_dim
334
+ else:
335
+ raise NotImplementedError
336
+
337
+ self.use_skip = skip
338
+
339
+ # norm layers
340
+ if norm_layer == 'layernorm':
341
+ norm_layer = nn.LayerNorm
342
+ elif norm_layer == 'rmsnorm':
343
+ norm_layer = RMSNorm
344
+ else:
345
+ raise NotImplementedError
346
+
347
+ self.in_blocks = nn.ModuleList([
348
+ LayerFusionDiTBlock(
349
+ dim=embed_dim,
350
+ ta_context_dim=ta_context_dim,
351
+ ta_context_fusion=ta_context_fusion,
352
+ ta_context_norm=ta_context_norm,
353
+ context_dim=context_dim,
354
+ num_heads=num_heads,
355
+ mlp_ratio=mlp_ratio,
356
+ qkv_bias=qkv_bias,
357
+ qk_scale=qk_scale,
358
+ qk_norm=qk_norm,
359
+ act_layer=act_layer,
360
+ norm_layer=norm_layer,
361
+ time_fusion=time_fusion,
362
+ ada_sola_rank=ada_sola_rank,
363
+ ada_sola_alpha=ada_sola_alpha,
364
+ skip=False,
365
+ skip_norm=False,
366
+ rope_mode=self.rope,
367
+ context_norm=context_norm,
368
+ use_checkpoint=use_checkpoint
369
+ ) for i in range(depth // 2)
370
+ ])
371
+
372
+ self.mid_block = LayerFusionDiTBlock(
373
+ dim=embed_dim,
374
+ ta_context_dim=ta_context_dim,
375
+ context_dim=context_dim,
376
+ num_heads=num_heads,
377
+ mlp_ratio=mlp_ratio,
378
+ qkv_bias=qkv_bias,
379
+ qk_scale=qk_scale,
380
+ qk_norm=qk_norm,
381
+ act_layer=act_layer,
382
+ norm_layer=norm_layer,
383
+ time_fusion=time_fusion,
384
+ ada_sola_rank=ada_sola_rank,
385
+ ada_sola_alpha=ada_sola_alpha,
386
+ ta_context_fusion=ta_context_fusion,
387
+ ta_context_norm=ta_context_norm,
388
+ skip=False,
389
+ skip_norm=False,
390
+ rope_mode=self.rope,
391
+ context_norm=context_norm,
392
+ use_checkpoint=use_checkpoint
393
+ )
394
+
395
+ self.out_blocks = nn.ModuleList([
396
+ LayerFusionDiTBlock(
397
+ dim=embed_dim,
398
+ ta_context_dim=ta_context_dim,
399
+ context_dim=context_dim,
400
+ num_heads=num_heads,
401
+ mlp_ratio=mlp_ratio,
402
+ qkv_bias=qkv_bias,
403
+ qk_scale=qk_scale,
404
+ qk_norm=qk_norm,
405
+ act_layer=act_layer,
406
+ norm_layer=norm_layer,
407
+ time_fusion=time_fusion,
408
+ ada_sola_rank=ada_sola_rank,
409
+ ada_sola_alpha=ada_sola_alpha,
410
+ ta_context_fusion=ta_context_fusion,
411
+ ta_context_norm=ta_context_norm,
412
+ skip=skip,
413
+ skip_norm=skip_norm,
414
+ rope_mode=self.rope,
415
+ context_norm=context_norm,
416
+ use_checkpoint=use_checkpoint
417
+ ) for i in range(depth // 2)
418
+ ])
419
+
420
+ # FinalLayer block
421
+ self.use_conv = use_conv
422
+ self.final_block = FinalBlock(
423
+ embed_dim=embed_dim,
424
+ patch_size=patch_size,
425
+ img_size=img_size,
426
+ in_chans=out_chans,
427
+ input_type=input_type,
428
+ norm_layer=norm_layer,
429
+ use_conv=use_conv,
430
+ use_adanorm=self.use_adanorm
431
+ )
432
+ self.initialize_weights()
433
+
434
+ def forward(
435
+ self,
436
+ x,
437
+ timesteps,
438
+ time_aligned_context,
439
+ context,
440
+ x_mask=None,
441
+ context_mask=None,
442
+ cls_token=None,
443
+ controlnet_skips=None,
444
+ ):
445
+ # make it compatible with int time step during inference
446
+ if timesteps.dim() == 0:
447
+ timesteps = timesteps.expand(x.shape[0]
448
+ ).to(x.device, dtype=torch.long)
449
+
450
+ x = self.patch_embed(x)
451
+ x = self.x_pe(x)
452
+
453
+ B, L, D = x.shape
454
+
455
+ if self.use_context:
456
+ context_token = self.context_embed(context)
457
+ context_token = self.context_pe(context_token)
458
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
459
+ x, x_mask = self._concat_x_context(
460
+ x=x,
461
+ context=context_token,
462
+ x_mask=x_mask,
463
+ context_mask=context_mask
464
+ )
465
+ context_token, context_mask = None, None
466
+ else:
467
+ context_token, context_mask = None, None
468
+
469
+ time_token = self.time_embed(timesteps)
470
+ if self.cls_embed:
471
+ cls_token = self.cls_embed(cls_token)
472
+ time_ada = None
473
+ time_ada_final = None
474
+ if self.use_adanorm:
475
+ if self.cls_embed:
476
+ time_token = time_token + cls_token
477
+ time_token = self.time_act(time_token)
478
+ time_ada_final = self.time_ada_final(time_token)
479
+ if self.time_ada is not None:
480
+ time_ada = self.time_ada(time_token)
481
+ else:
482
+ time_token = time_token.unsqueeze(dim=1)
483
+ if self.cls_embed:
484
+ cls_token = cls_token.unsqueeze(dim=1)
485
+ time_token = torch.cat([time_token, cls_token], dim=1)
486
+ time_token = self.time_pe(time_token)
487
+ x = torch.cat((time_token, x), dim=1)
488
+ if x_mask is not None:
489
+ x_mask = torch.cat([
490
+ torch.ones(B, time_token.shape[1],
491
+ device=x_mask.device).bool(), x_mask
492
+ ],
493
+ dim=1)
494
+ time_token = None
495
+
496
+ skips = []
497
+ for blk_idx, blk in enumerate(self.in_blocks):
498
+ x = blk(
499
+ x=x,
500
+ time_aligned_context=time_aligned_context,
501
+ time_token=time_token,
502
+ time_ada=time_ada,
503
+ skip=None,
504
+ context=context_token,
505
+ x_mask=x_mask,
506
+ context_mask=context_mask,
507
+ extras=self.extras
508
+ )
509
+ # if not self.training:
510
+ # print(
511
+ # f"in block {blk_idx}, min: {x.min().item()}, max: {x.max().item()}, std: {x.std().item()}"
512
+ # )
513
+ if self.use_skip:
514
+ skips.append(x)
515
+
516
+ x = self.mid_block(
517
+ x=x,
518
+ time_aligned_context=time_aligned_context,
519
+ time_token=time_token,
520
+ time_ada=time_ada,
521
+ skip=None,
522
+ context=context_token,
523
+ x_mask=x_mask,
524
+ context_mask=context_mask,
525
+ extras=self.extras
526
+ )
527
+ for blk_idx, blk in enumerate(self.out_blocks):
528
+ if self.use_skip:
529
+ skip = skips.pop()
530
+ if controlnet_skips:
531
+ # add to skip like u-net controlnet
532
+ skip = skip + controlnet_skips.pop()
533
+ else:
534
+ skip = None
535
+ if controlnet_skips:
536
+ # directly add to x
537
+ x = x + controlnet_skips.pop()
538
+
539
+ x = blk(
540
+ x=x,
541
+ time_aligned_context=time_aligned_context,
542
+ time_token=time_token,
543
+ time_ada=time_ada,
544
+ skip=skip,
545
+ context=context_token,
546
+ x_mask=x_mask,
547
+ context_mask=context_mask,
548
+ extras=self.extras
549
+ )
550
+ # if not self.training:
551
+ # print(
552
+ # f"out block {blk_idx}, min: {x.min().item()}, max: {x.max().item()}, std: {x.std().item()}"
553
+ # )
554
+
555
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
556
+
557
+ return x
558
+
559
+
560
+ class InputFusionAudioDiT(UDiT):
561
+ def __init__(
562
+ self,
563
+ img_size=224,
564
+ patch_size=16,
565
+ in_chans=3,
566
+ input_type='2d',
567
+ out_chans=None,
568
+ embed_dim=768,
569
+ depth=12,
570
+ num_heads=12,
571
+ mlp_ratio=4,
572
+ qkv_bias=False,
573
+ qk_scale=None,
574
+ qk_norm=None,
575
+ act_layer='gelu',
576
+ norm_layer='layernorm',
577
+ context_norm=False,
578
+ use_checkpoint=False,
579
+ time_fusion='token',
580
+ ada_sola_rank=None,
581
+ ada_sola_alpha=None,
582
+ cls_dim=None,
583
+ ta_context_dim=768,
584
+ context_dim=768,
585
+ context_fusion='concat',
586
+ context_max_length=128,
587
+ context_pe_method='sinu',
588
+ pe_method='abs',
589
+ rope_mode='none',
590
+ use_conv=True,
591
+ skip=True,
592
+ skip_norm=True
593
+ ):
594
+ super().__init__(
595
+ img_size,
596
+ patch_size,
597
+ in_chans,
598
+ input_type,
599
+ out_chans,
600
+ embed_dim,
601
+ depth,
602
+ num_heads,
603
+ mlp_ratio,
604
+ qkv_bias,
605
+ qk_scale,
606
+ qk_norm,
607
+ act_layer,
608
+ norm_layer,
609
+ context_norm,
610
+ use_checkpoint,
611
+ time_fusion,
612
+ ada_sola_rank,
613
+ ada_sola_alpha,
614
+ cls_dim,
615
+ context_dim,
616
+ context_fusion,
617
+ context_max_length,
618
+ context_pe_method,
619
+ pe_method,
620
+ rope_mode,
621
+ use_conv,
622
+ skip,
623
+ skip_norm,
624
+ )
625
+ self.input_proj = nn.Linear(in_chans + ta_context_dim, in_chans)
626
+ nn.init.xavier_uniform_(self.input_proj.weight)
627
+ nn.init.constant_(self.input_proj.bias, 0)
628
+
629
+ def forward(
630
+ self,
631
+ x,
632
+ timesteps,
633
+ time_aligned_context,
634
+ context,
635
+ x_mask=None,
636
+ context_mask=None,
637
+ cls_token=None,
638
+ controlnet_skips=None
639
+ ):
640
+ x = self.input_proj(
641
+ torch.cat([x.transpose(1, 2), time_aligned_context], dim=-1)
642
+ )
643
+ x = x.transpose(1, 2)
644
+ return super().forward(
645
+ x=x,
646
+ timesteps=timesteps,
647
+ context=context,
648
+ x_mask=x_mask,
649
+ context_mask=context_mask,
650
+ cls_token=cls_token,
651
+ controlnet_skips=controlnet_skips
652
+ )
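Note: InputFusionAudioDiT above fuses the time-aligned context once, at the input: the latent (B, C, T) and the context (B, T, D) are concatenated along the feature axis and projected back to C channels before the ordinary UDiT forward pass. A pure-tensor sketch of that projection with illustrative sizes:

    import torch
    import torch.nn as nn

    B, C, T, D = 2, 8, 64, 16
    x = torch.randn(B, C, T)                     # latent
    ta_ctx = torch.randn(B, T, D)                # time-aligned context

    input_proj = nn.Linear(C + D, C)             # mirrors self.input_proj above
    fused = input_proj(torch.cat([x.transpose(1, 2), ta_ctx], dim=-1))
    fused = fused.transpose(1, 2)                # back to (B, C, T) for patch embedding
    print(fused.shape)                           # torch.Size([2, 8, 64])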
models/dit/mask_dit.py ADDED
@@ -0,0 +1,823 @@
1
+ import logging
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.checkpoint import checkpoint
6
+
7
+ from .modules import (
8
+ film_modulate,
9
+ unpatchify,
10
+ PatchEmbed,
11
+ PE_wrapper,
12
+ TimestepEmbedder,
13
+ FeedForward,
14
+ RMSNorm,
15
+ )
16
+ from .span_mask import compute_mask_indices
17
+ from .attention import Attention
18
+
19
+ logger = logging.Logger(__file__)
20
+
21
+
22
+ class AdaLN(nn.Module):
23
+ def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
24
+ super().__init__()
25
+ self.ada_mode = ada_mode
26
+ self.scale_shift_table = None
27
+ if ada_mode == 'ada':
28
+ # move nn.silu outside
29
+ self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
30
+ elif ada_mode == 'ada_single':
31
+ # adaLN-single, as used in PixArt-alpha
32
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
33
+ elif ada_mode in ['ada_sola', 'ada_sola_bias']:
34
+ self.lora_a = nn.Linear(dim, r * 6, bias=False)
35
+ self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
36
+ self.scaling = alpha / r
37
+ if ada_mode == 'ada_sola_bias':
38
+ # take bias out for consistency
39
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
40
+ else:
41
+ raise NotImplementedError
42
+
43
+ def forward(self, time_token=None, time_ada=None):
44
+ if self.ada_mode == 'ada':
45
+ assert time_ada is None
46
+ B = time_token.shape[0]
47
+ time_ada = self.time_ada(time_token).reshape(B, 6, -1)
48
+ elif self.ada_mode == 'ada_single':
49
+ B = time_ada.shape[0]
50
+ time_ada = time_ada.reshape(B, 6, -1)
51
+ time_ada = self.scale_shift_table[None] + time_ada
52
+ elif self.ada_mode in ['ada_sola', 'ada_sola_bias']:
53
+ B = time_ada.shape[0]
54
+ time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
55
+ time_ada = time_ada + time_ada_lora
56
+ time_ada = time_ada.reshape(B, 6, -1)
57
+ if self.scale_shift_table is not None:
58
+ time_ada = self.scale_shift_table[None] + time_ada
59
+ else:
60
+ raise NotImplementedError
61
+ return time_ada
62
+
63
+
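Note: the 'ada_sola' branch of AdaLN above adds a low-rank (LoRA-style) correction, scaled by alpha / r, to a shared adaLN projection and reshapes the result into the six shift/scale/gate vectors consumed by DiTBlock. A pure-tensor sketch with illustrative sizes:

    import torch
    import torch.nn as nn

    dim, r, alpha, B = 8, 2, 4, 3
    time_token = torch.randn(B, dim)             # SiLU-activated timestep embedding
    shared_time_ada = torch.randn(B, 6 * dim)    # output of the shared Linear(dim, 6 * dim)

    lora_a = nn.Linear(dim, r * 6, bias=False)
    lora_b = nn.Linear(r * 6, dim * 6, bias=False)
    scaling = alpha / r

    time_ada = shared_time_ada + lora_b(lora_a(time_token)) * scaling
    time_ada = time_ada.reshape(B, 6, dim)
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = time_ada.chunk(6, dim=1)
    print(shift_msa.shape)                       # torch.Size([3, 1, 8])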
64
+ class DiTBlock(nn.Module):
65
+ """
66
+ A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
67
+ """
68
+ def __init__(
69
+ self,
70
+ dim,
71
+ context_dim=None,
72
+ num_heads=8,
73
+ mlp_ratio=4.,
74
+ qkv_bias=False,
75
+ qk_scale=None,
76
+ qk_norm=None,
77
+ act_layer='gelu',
78
+ norm_layer=nn.LayerNorm,
79
+ time_fusion='none',
80
+ ada_sola_rank=None,
81
+ ada_sola_alpha=None,
82
+ skip=False,
83
+ skip_norm=False,
84
+ rope_mode='none',
85
+ context_norm=False,
86
+ use_checkpoint=False
87
+ ):
88
+
89
+ super().__init__()
90
+ self.norm1 = norm_layer(dim)
91
+ self.attn = Attention(
92
+ dim=dim,
93
+ num_heads=num_heads,
94
+ qkv_bias=qkv_bias,
95
+ qk_scale=qk_scale,
96
+ qk_norm=qk_norm,
97
+ rope_mode=rope_mode
98
+ )
99
+
100
+ if context_dim is not None:
101
+ self.use_context = True
102
+ self.cross_attn = Attention(
103
+ dim=dim,
104
+ num_heads=num_heads,
105
+ context_dim=context_dim,
106
+ qkv_bias=qkv_bias,
107
+ qk_scale=qk_scale,
108
+ qk_norm=qk_norm,
109
+ rope_mode='none'
110
+ )
111
+ self.norm2 = norm_layer(dim)
112
+ if context_norm:
113
+ self.norm_context = norm_layer(context_dim)
114
+ else:
115
+ self.norm_context = nn.Identity()
116
+ else:
117
+ self.use_context = False
118
+
119
+ self.norm3 = norm_layer(dim)
120
+ self.mlp = FeedForward(
121
+ dim=dim, mult=mlp_ratio, activation_fn=act_layer, dropout=0
122
+ )
123
+
124
+ self.use_adanorm = True if time_fusion != 'token' else False
125
+ if self.use_adanorm:
126
+ self.adaln = AdaLN(
127
+ dim,
128
+ ada_mode=time_fusion,
129
+ r=ada_sola_rank,
130
+ alpha=ada_sola_alpha
131
+ )
132
+ if skip:
133
+ self.skip_norm = norm_layer(2 *
134
+ dim) if skip_norm else nn.Identity()
135
+ self.skip_linear = nn.Linear(2 * dim, dim)
136
+ else:
137
+ self.skip_linear = None
138
+
139
+ self.use_checkpoint = use_checkpoint
140
+
141
+ def forward(
142
+ self,
143
+ x,
144
+ time_token=None,
145
+ time_ada=None,
146
+ skip=None,
147
+ context=None,
148
+ x_mask=None,
149
+ context_mask=None,
150
+ extras=None
151
+ ):
152
+ if self.use_checkpoint:
153
+ return checkpoint(
154
+ self._forward,
155
+ x,
156
+ time_token,
157
+ time_ada,
158
+ skip,
159
+ context,
160
+ x_mask,
161
+ context_mask,
162
+ extras,
163
+ use_reentrant=False
164
+ )
165
+ else:
166
+ return self._forward(
167
+ x, time_token, time_ada, skip, context, x_mask, context_mask,
168
+ extras
169
+ )
170
+
171
+ def _forward(
172
+ self,
173
+ x,
174
+ time_token=None,
175
+ time_ada=None,
176
+ skip=None,
177
+ context=None,
178
+ x_mask=None,
179
+ context_mask=None,
180
+ extras=None
181
+ ):
182
+ B, T, C = x.shape
183
+ if self.skip_linear is not None:
184
+ assert skip is not None
185
+ cat = torch.cat([x, skip], dim=-1)
186
+ cat = self.skip_norm(cat)
187
+ x = self.skip_linear(cat)
188
+
189
+ if self.use_adanorm:
190
+ time_ada = self.adaln(time_token, time_ada)
191
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
192
+ gate_mlp) = time_ada.chunk(6, dim=1)
193
+
194
+ # self attention
195
+ if self.use_adanorm:
196
+ x_norm = film_modulate(
197
+ self.norm1(x), shift=shift_msa, scale=scale_msa
198
+ )
199
+ x = x + (1 - gate_msa) * self.attn(
200
+ x_norm, context=None, context_mask=x_mask, extras=extras
201
+ )
202
+ else:
203
+ x = x + self.attn(
204
+ self.norm1(x),
205
+ context=None,
206
+ context_mask=x_mask,
207
+ extras=extras
208
+ )
209
+
210
+ # cross attention
211
+ if self.use_context:
212
+ assert context is not None
213
+ x = x + self.cross_attn(
214
+ x=self.norm2(x),
215
+ context=self.norm_context(context),
216
+ context_mask=context_mask,
217
+ extras=extras
218
+ )
219
+
220
+ # mlp
221
+ if self.use_adanorm:
222
+ x_norm = film_modulate(
223
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
224
+ )
225
+ x = x + (1 - gate_mlp) * self.mlp(x_norm)
226
+ else:
227
+ x = x + self.mlp(self.norm3(x))
228
+
229
+ return x
230
+
231
+
232
+ class FinalBlock(nn.Module):
233
+ def __init__(
234
+ self,
235
+ embed_dim,
236
+ patch_size,
237
+ in_chans,
238
+ img_size,
239
+ input_type='2d',
240
+ norm_layer=nn.LayerNorm,
241
+ use_conv=True,
242
+ use_adanorm=True
243
+ ):
244
+ super().__init__()
245
+ self.in_chans = in_chans
246
+ self.img_size = img_size
247
+ self.input_type = input_type
248
+
249
+ self.norm = norm_layer(embed_dim)
250
+ if use_adanorm:
251
+ self.use_adanorm = True
252
+ else:
253
+ self.use_adanorm = False
254
+
255
+ if input_type == '2d':
256
+ self.patch_dim = patch_size**2 * in_chans
257
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
258
+ if use_conv:
259
+ self.final_layer = nn.Conv2d(
260
+ self.in_chans, self.in_chans, 3, padding=1
261
+ )
262
+ else:
263
+ self.final_layer = nn.Identity()
264
+
265
+ elif input_type == '1d':
266
+ self.patch_dim = patch_size * in_chans
267
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
268
+ if use_conv:
269
+ self.final_layer = nn.Conv1d(
270
+ self.in_chans, self.in_chans, 3, padding=1
271
+ )
272
+ else:
273
+ self.final_layer = nn.Identity()
274
+
275
+ def forward(self, x, time_ada=None, extras=0):
276
+ B, T, C = x.shape
277
+ x = x[:, extras:, :]
278
+ # only handle generation target
279
+ if self.use_adanorm:
280
+ shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
281
+ x = film_modulate(self.norm(x), shift, scale)
282
+ else:
283
+ x = self.norm(x)
284
+ x = self.linear(x)
285
+ x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
286
+ x = self.final_layer(x)
287
+ return x
288
+
289
+
290
+ class UDiT(nn.Module):
291
+ def __init__(
292
+ self,
293
+ img_size=224,
294
+ patch_size=16,
295
+ in_chans=3,
296
+ input_type='2d',
297
+ out_chans=None,
298
+ embed_dim=768,
299
+ depth=12,
300
+ num_heads=12,
301
+ mlp_ratio=4.,
302
+ qkv_bias=False,
303
+ qk_scale=None,
304
+ qk_norm=None,
305
+ act_layer='gelu',
306
+ norm_layer='layernorm',
307
+ context_norm=False,
308
+ use_checkpoint=False,
309
+ # time fusion ada or token
310
+ time_fusion='token',
311
+ ada_sola_rank=None,
312
+ ada_sola_alpha=None,
313
+ cls_dim=None,
314
+ # max length is only used for concat
315
+ context_dim=768,
316
+ context_fusion='concat',
317
+ context_max_length=128,
318
+ context_pe_method='sinu',
319
+ pe_method='abs',
320
+ rope_mode='none',
321
+ use_conv=True,
322
+ skip=True,
323
+ skip_norm=True
324
+ ):
325
+ super().__init__()
326
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
327
+
328
+ # input
329
+ self.in_chans = in_chans
330
+ self.input_type = input_type
331
+ if self.input_type == '2d':
332
+ num_patches = (img_size[0] //
333
+ patch_size) * (img_size[1] // patch_size)
334
+ elif self.input_type == '1d':
335
+ num_patches = img_size // patch_size
336
+ self.patch_embed = PatchEmbed(
337
+ patch_size=patch_size,
338
+ in_chans=in_chans,
339
+ embed_dim=embed_dim,
340
+ input_type=input_type
341
+ )
342
+ out_chans = in_chans if out_chans is None else out_chans
343
+ self.out_chans = out_chans
344
+
345
+ # position embedding
346
+ self.rope = rope_mode
347
+ self.x_pe = PE_wrapper(
348
+ dim=embed_dim, method=pe_method, length=num_patches
349
+ )
350
+
351
+ logger.info(f'x position embedding: {pe_method}')
352
+ logger.info(f'rope mode: {self.rope}')
353
+
354
+ # time embed
355
+ self.time_embed = TimestepEmbedder(embed_dim)
356
+ self.time_fusion = time_fusion
357
+ self.use_adanorm = False
358
+
359
+ # cls embed
360
+ if cls_dim is not None:
361
+ self.cls_embed = nn.Sequential(
362
+ nn.Linear(cls_dim, embed_dim, bias=True),
363
+ nn.SiLU(),
364
+ nn.Linear(embed_dim, embed_dim, bias=True),
365
+ )
366
+ else:
367
+ self.cls_embed = None
368
+
369
+ # time fusion
370
+ if time_fusion == 'token':
371
+ # put token at the beginning of sequence
372
+ self.extras = 2 if self.cls_embed else 1
373
+ self.time_pe = PE_wrapper(
374
+ dim=embed_dim, method='abs', length=self.extras
375
+ )
376
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
377
+ self.use_adanorm = True
378
+ # avoid repetitive SiLU for each AdaLN block
379
+ self.time_act = nn.SiLU()
380
+ self.extras = 0
381
+ self.time_ada_final = nn.Linear(
382
+ embed_dim, 2 * embed_dim, bias=True
383
+ )
384
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
385
+ # shared adaln
386
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
387
+ else:
388
+ self.time_ada = None
389
+ else:
390
+ raise NotImplementedError
391
+ logger.info(f'time fusion mode: {self.time_fusion}')
392
+
393
+ # context
394
+ # use a simple projection
395
+ self.use_context = False
396
+ self.context_cross = False
397
+ self.context_max_length = context_max_length
398
+ self.context_fusion = 'none'
399
+ if context_dim is not None:
400
+ self.use_context = True
401
+ self.context_embed = nn.Sequential(
402
+ nn.Linear(context_dim, embed_dim, bias=True),
403
+ nn.SiLU(),
404
+ nn.Linear(embed_dim, embed_dim, bias=True),
405
+ )
406
+ self.context_fusion = context_fusion
407
+ if context_fusion == 'concat' or context_fusion == 'joint':
408
+ self.extras += context_max_length
409
+ self.context_pe = PE_wrapper(
410
+ dim=embed_dim,
411
+ method=context_pe_method,
412
+ length=context_max_length
413
+ )
414
+ # no cross attention layers
415
+ context_dim = None
416
+ elif context_fusion == 'cross':
417
+ self.context_pe = PE_wrapper(
418
+ dim=embed_dim,
419
+ method=context_pe_method,
420
+ length=context_max_length
421
+ )
422
+ self.context_cross = True
423
+ context_dim = embed_dim
424
+ else:
425
+ raise NotImplementedError
426
+ logger.info(f'context fusion mode: {context_fusion}')
427
+ logger.info(f'context position embedding: {context_pe_method}')
428
+
429
+ self.use_skip = skip
430
+
431
+ # norm layers
432
+ if norm_layer == 'layernorm':
433
+ norm_layer = nn.LayerNorm
434
+ elif norm_layer == 'rmsnorm':
435
+ norm_layer = RMSNorm
436
+ else:
437
+ raise NotImplementedError
438
+
439
+ logger.info(f'use long skip connection: {skip}')
440
+ self.in_blocks = nn.ModuleList([
441
+ DiTBlock(
442
+ dim=embed_dim,
443
+ context_dim=context_dim,
444
+ num_heads=num_heads,
445
+ mlp_ratio=mlp_ratio,
446
+ qkv_bias=qkv_bias,
447
+ qk_scale=qk_scale,
448
+ qk_norm=qk_norm,
449
+ act_layer=act_layer,
450
+ norm_layer=norm_layer,
451
+ time_fusion=time_fusion,
452
+ ada_sola_rank=ada_sola_rank,
453
+ ada_sola_alpha=ada_sola_alpha,
454
+ skip=False,
455
+ skip_norm=False,
456
+ rope_mode=self.rope,
457
+ context_norm=context_norm,
458
+ use_checkpoint=use_checkpoint
459
+ ) for _ in range(depth // 2)
460
+ ])
461
+
462
+ self.mid_block = DiTBlock(
463
+ dim=embed_dim,
464
+ context_dim=context_dim,
465
+ num_heads=num_heads,
466
+ mlp_ratio=mlp_ratio,
467
+ qkv_bias=qkv_bias,
468
+ qk_scale=qk_scale,
469
+ qk_norm=qk_norm,
470
+ act_layer=act_layer,
471
+ norm_layer=norm_layer,
472
+ time_fusion=time_fusion,
473
+ ada_sola_rank=ada_sola_rank,
474
+ ada_sola_alpha=ada_sola_alpha,
475
+ skip=False,
476
+ skip_norm=False,
477
+ rope_mode=self.rope,
478
+ context_norm=context_norm,
479
+ use_checkpoint=use_checkpoint
480
+ )
481
+
482
+ self.out_blocks = nn.ModuleList([
483
+ DiTBlock(
484
+ dim=embed_dim,
485
+ context_dim=context_dim,
486
+ num_heads=num_heads,
487
+ mlp_ratio=mlp_ratio,
488
+ qkv_bias=qkv_bias,
489
+ qk_scale=qk_scale,
490
+ qk_norm=qk_norm,
491
+ act_layer=act_layer,
492
+ norm_layer=norm_layer,
493
+ time_fusion=time_fusion,
494
+ ada_sola_rank=ada_sola_rank,
495
+ ada_sola_alpha=ada_sola_alpha,
496
+ skip=skip,
497
+ skip_norm=skip_norm,
498
+ rope_mode=self.rope,
499
+ context_norm=context_norm,
500
+ use_checkpoint=use_checkpoint
501
+ ) for _ in range(depth // 2)
502
+ ])
503
+
504
+ # FinalLayer block
505
+ self.use_conv = use_conv
506
+ self.final_block = FinalBlock(
507
+ embed_dim=embed_dim,
508
+ patch_size=patch_size,
509
+ img_size=img_size,
510
+ in_chans=out_chans,
511
+ input_type=input_type,
512
+ norm_layer=norm_layer,
513
+ use_conv=use_conv,
514
+ use_adanorm=self.use_adanorm
515
+ )
516
+ self.initialize_weights()
517
+
518
+ def _init_ada(self):
519
+ if self.time_fusion == 'ada':
520
+ nn.init.constant_(self.time_ada_final.weight, 0)
521
+ nn.init.constant_(self.time_ada_final.bias, 0)
522
+ for block in self.in_blocks:
523
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
524
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
525
+ nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
526
+ nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
527
+ for block in self.out_blocks:
528
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
529
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
530
+ elif self.time_fusion == 'ada_single':
531
+ nn.init.constant_(self.time_ada.weight, 0)
532
+ nn.init.constant_(self.time_ada.bias, 0)
533
+ nn.init.constant_(self.time_ada_final.weight, 0)
534
+ nn.init.constant_(self.time_ada_final.bias, 0)
535
+ elif self.time_fusion in ['ada_sola', 'ada_sola_bias']:
536
+ nn.init.constant_(self.time_ada.weight, 0)
537
+ nn.init.constant_(self.time_ada.bias, 0)
538
+ nn.init.constant_(self.time_ada_final.weight, 0)
539
+ nn.init.constant_(self.time_ada_final.bias, 0)
540
+ for block in self.in_blocks:
541
+ nn.init.kaiming_uniform_(
542
+ block.adaln.lora_a.weight, a=math.sqrt(5)
543
+ )
544
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
545
+ nn.init.kaiming_uniform_(
546
+ self.mid_block.adaln.lora_a.weight, a=math.sqrt(5)
547
+ )
548
+ nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
549
+ for block in self.out_blocks:
550
+ nn.init.kaiming_uniform_(
551
+ block.adaln.lora_a.weight, a=math.sqrt(5)
552
+ )
553
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
554
+
555
+ def initialize_weights(self):
556
+ # Basic init for all layers
557
+ def _basic_init(module):
558
+ if isinstance(module, nn.Linear):
559
+ nn.init.xavier_uniform_(module.weight)
560
+ if module.bias is not None:
561
+ nn.init.constant_(module.bias, 0)
562
+
563
+ self.apply(_basic_init)
564
+
565
+ # init patch Conv like Linear
566
+ w = self.patch_embed.proj.weight.data
567
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
568
+ nn.init.constant_(self.patch_embed.proj.bias, 0)
569
+
570
+ # Zero-out AdaLN
571
+ if self.use_adanorm:
572
+ self._init_ada()
573
+
574
+ # Zero-out Cross Attention
575
+ if self.context_cross:
576
+ for block in self.in_blocks:
577
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
578
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
579
+ nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
580
+ nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
581
+ for block in self.out_blocks:
582
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
583
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
584
+
585
+ # Zero-out cls embedding
586
+ if self.cls_embed:
587
+ if self.use_adanorm:
588
+ nn.init.constant_(self.cls_embed[-1].weight, 0)
589
+ nn.init.constant_(self.cls_embed[-1].bias, 0)
590
+
591
+ # Zero-out Output
592
+ # might not zero-out this when using v-prediction
593
+ # it could be good when using noise-prediction
594
+ # nn.init.constant_(self.final_block.linear.weight, 0)
595
+ # nn.init.constant_(self.final_block.linear.bias, 0)
596
+ # if self.use_conv:
597
+ # nn.init.constant_(self.final_block.final_layer.weight.data, 0)
598
+ # nn.init.constant_(self.final_block.final_layer.bias, 0)
599
+
600
+ # init out Conv
601
+ if self.use_conv:
602
+ nn.init.xavier_uniform_(self.final_block.final_layer.weight)
603
+ nn.init.constant_(self.final_block.final_layer.bias, 0)
604
+
605
+ def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
606
+ assert context.shape[-2] == self.context_max_length
607
+ # Check if either x_mask or context_mask is provided
608
+ B = x.shape[0]
609
+ # Create default masks if they are not provided
610
+ if x_mask is None:
611
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
612
+ if context_mask is None:
613
+ context_mask = torch.ones(
614
+ B, context.shape[-2], device=context.device
615
+ ).bool()
616
+ # Concatenate the masks along the second dimension (dim=1)
617
+ x_mask = torch.cat([context_mask, x_mask], dim=1)
618
+ # Concatenate context and x along the second dimension (dim=1)
619
+ x = torch.cat((context, x), dim=1)
620
+ return x, x_mask
621
+
622
+ def forward(
623
+ self,
624
+ x,
625
+ timesteps,
626
+ context,
627
+ x_mask=None,
628
+ context_mask=None,
629
+ cls_token=None,
630
+ controlnet_skips=None,
631
+ ):
632
+ # make it compatible with int time step during inference
633
+ if timesteps.dim() == 0:
634
+ timesteps = timesteps.expand(x.shape[0]
635
+ ).to(x.device, dtype=torch.long)
636
+
637
+ x = self.patch_embed(x)
638
+ x = self.x_pe(x)
639
+
640
+ B, L, D = x.shape
641
+
642
+ if self.use_context:
643
+ context_token = self.context_embed(context)
644
+ context_token = self.context_pe(context_token)
645
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
646
+ x, x_mask = self._concat_x_context(
647
+ x=x,
648
+ context=context_token,
649
+ x_mask=x_mask,
650
+ context_mask=context_mask
651
+ )
652
+ context_token, context_mask = None, None
653
+ else:
654
+ context_token, context_mask = None, None
655
+
656
+ time_token = self.time_embed(timesteps)
657
+ if self.cls_embed:
658
+ cls_token = self.cls_embed(cls_token)
659
+ time_ada = None
660
+ time_ada_final = None
661
+ if self.use_adanorm:
662
+ if self.cls_embed:
663
+ time_token = time_token + cls_token
664
+ time_token = self.time_act(time_token)
665
+ time_ada_final = self.time_ada_final(time_token)
666
+ if self.time_ada is not None:
667
+ time_ada = self.time_ada(time_token)
668
+ else:
669
+ time_token = time_token.unsqueeze(dim=1)
670
+ if self.cls_embed:
671
+ cls_token = cls_token.unsqueeze(dim=1)
672
+ time_token = torch.cat([time_token, cls_token], dim=1)
673
+ time_token = self.time_pe(time_token)
674
+ x = torch.cat((time_token, x), dim=1)
675
+ if x_mask is not None:
676
+ x_mask = torch.cat([
677
+ torch.ones(B, time_token.shape[1],
678
+ device=x_mask.device).bool(), x_mask
679
+ ],
680
+ dim=1)
681
+ time_token = None
682
+
683
+ skips = []
684
+ for blk in self.in_blocks:
685
+ x = blk(
686
+ x=x,
687
+ time_token=time_token,
688
+ time_ada=time_ada,
689
+ skip=None,
690
+ context=context_token,
691
+ x_mask=x_mask,
692
+ context_mask=context_mask,
693
+ extras=self.extras
694
+ )
695
+ if self.use_skip:
696
+ skips.append(x)
697
+
698
+ x = self.mid_block(
699
+ x=x,
700
+ time_token=time_token,
701
+ time_ada=time_ada,
702
+ skip=None,
703
+ context=context_token,
704
+ x_mask=x_mask,
705
+ context_mask=context_mask,
706
+ extras=self.extras
707
+ )
708
+ for blk in self.out_blocks:
709
+ if self.use_skip:
710
+ skip = skips.pop()
711
+ if controlnet_skips:
712
+ # add to skip like u-net controlnet
713
+ skip = skip + controlnet_skips.pop()
714
+ else:
715
+ skip = None
716
+ if controlnet_skips:
717
+ # directly add to x
718
+ x = x + controlnet_skips.pop()
719
+
720
+ x = blk(
721
+ x=x,
722
+ time_token=time_token,
723
+ time_ada=time_ada,
724
+ skip=skip,
725
+ context=context_token,
726
+ x_mask=x_mask,
727
+ context_mask=context_mask,
728
+ extras=self.extras
729
+ )
730
+
731
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
732
+
733
+ return x
734
+
735
+
736
+ class MaskDiT(nn.Module):
737
+ def __init__(
738
+ self,
739
+ model: UDiT,
740
+ mae=False,
741
+ mae_prob=0.5,
742
+ mask_ratio=[0.25, 1.0],
743
+ mask_span=10,
744
+ ):
745
+ super().__init__()
746
+ self.model = model
747
+ self.mae = mae
748
+ if self.mae:
749
+ out_channel = model.out_chans
750
+ self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
751
+ self.mae_prob = mae_prob
752
+ self.mask_ratio = mask_ratio
753
+ self.mask_span = mask_span
754
+
755
+ def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
756
+ B, D, L = gt.shape
757
+ if mae_mask_infer is None:
758
+ # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
759
+ mask_ratios = mask_ratios.cpu().numpy()
760
+ mask = compute_mask_indices(
761
+ shape=[B, L],
762
+ padding_mask=None,
763
+ mask_prob=mask_ratios,
764
+ mask_length=self.mask_span,
765
+ mask_type="static",
766
+ mask_other=0.0,
767
+ min_masks=1,
768
+ no_overlap=False,
769
+ min_space=0,
770
+ )
771
+ mask = mask.unsqueeze(1).expand_as(gt)
772
+ else:
773
+ mask = mae_mask_infer
774
+ mask = mask.expand_as(gt)
775
+ gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
776
+ return gt, mask.type_as(gt)
777
+
778
+ def forward(
779
+ self,
780
+ x,
781
+ timesteps,
782
+ context,
783
+ x_mask=None,
784
+ context_mask=None,
785
+ cls_token=None,
786
+ gt=None,
787
+ mae_mask_infer=None,
788
+ forward_model=True
789
+ ):
790
+ # todo: handle controlnet inside
791
+ mae_mask = torch.ones_like(x)
792
+ if self.mae:
793
+ if gt is not None:
794
+ B, D, L = gt.shape
795
+ mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio
796
+ ).to(gt.device)
797
+ gt, mae_mask = self.random_masking(
798
+ gt, mask_ratios, mae_mask_infer
799
+ )
800
+ # apply mae only to the selected batches
801
+ if mae_mask_infer is None:
802
+ # determine mae batch
803
+ mae_batch = torch.rand(B) < self.mae_prob
804
+ gt[~mae_batch] = self.mask_embed.view(
805
+ 1, D, 1
806
+ ).expand_as(gt)[~mae_batch]
807
+ mae_mask[~mae_batch] = 1.0
808
+ else:
809
+ B, D, L = x.shape
810
+ gt = self.mask_embed.view(1, D, 1).expand_as(x)
811
+ x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)
812
+
813
+ if forward_model:
814
+ x = self.model(
815
+ x=x,
816
+ timesteps=timesteps,
817
+ context=context,
818
+ x_mask=x_mask,
819
+ context_mask=context_mask,
820
+ cls_token=cls_token
821
+ )
822
+ # logger.info(mae_mask[:, 0, :].sum(dim=-1))
823
+ return x, mae_mask
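
For reference, a minimal sketch (not part of the commit) of the channel bookkeeping that `MaskDiT.forward` performs before calling the backbone: the noisy latent, the partially masked ground-truth latent, and one channel of the mask are concatenated, so the wrapped `UDiT` needs `in_chans = 2 * latent_channels + 1`. All sizes below are illustrative.

import torch

B, D, L = 2, 8, 50                     # batch, latent channels, latent length (illustrative)
x = torch.randn(B, D, L)               # noisy latent
gt = torch.randn(B, D, L)              # ground-truth latent to be partially masked
mask_embed = torch.zeros(D)            # stands in for the learnable mask embedding

mae_mask = torch.zeros(B, D, L, dtype=torch.bool)
mae_mask[:, :, : L // 2] = True        # pretend the first half of each sequence is masked
gt = torch.where(mae_mask, mask_embed.view(1, D, 1).expand_as(gt), gt)

model_input = torch.cat([x, gt, mae_mask[:, 0:1, :].float()], dim=1)
print(model_input.shape)               # torch.Size([2, 17, 50]) -> 2 * D + 1 channels
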
models/dit/modules.py ADDED
@@ -0,0 +1,445 @@
1
+ import warnings
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint
6
+ from torch.cuda.amp import autocast
7
+ import math
8
+ import einops
9
+ from einops import rearrange, repeat
10
+ from inspect import isfunction
11
+
12
+
13
+ def trunc_normal_(tensor, mean, std, a, b):
14
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
15
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
16
+ def norm_cdf(x):
17
+ # Computes standard normal cumulative distribution function
18
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
19
+
20
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
21
+ warnings.warn(
22
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
23
+ "The distribution of values may be incorrect.",
24
+ stacklevel=2
25
+ )
26
+
27
+ with torch.no_grad():
28
+ # Values are generated by using a truncated uniform distribution and
29
+ # then using the inverse CDF for the normal distribution.
30
+ # Get upper and lower cdf values
31
+ l = norm_cdf((a - mean) / std)
32
+ u = norm_cdf((b - mean) / std)
33
+
34
+ # Uniformly fill tensor with values from [l, u], then translate to
35
+ # [2l-1, 2u-1].
36
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
37
+
38
+ # Use inverse cdf transform for normal distribution to get truncated
39
+ # standard normal
40
+ tensor.erfinv_()
41
+
42
+ # Transform to proper mean, std
43
+ tensor.mul_(std * math.sqrt(2.))
44
+ tensor.add_(mean)
45
+
46
+ # Clamp to ensure it's in the proper range
47
+ tensor.clamp_(min=a, max=b)
48
+ return tensor
49
+
50
+
51
+ # disable in checkpoint mode
52
+ # @torch.jit.script
53
+ def film_modulate(x, shift, scale):
54
+ return x * (1 + scale) + shift
55
+
56
+
57
+ def timestep_embedding(timesteps, dim, max_period=10000):
58
+ """
59
+ Create sinusoidal timestep embeddings.
60
+
61
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
62
+ These may be fractional.
63
+ :param dim: the dimension of the output.
64
+ :param max_period: controls the minimum frequency of the embeddings.
65
+ :return: an [N x dim] Tensor of positional embeddings.
66
+ """
67
+ half = dim // 2
68
+ freqs = torch.exp(
69
+ -math.log(max_period) *
70
+ torch.arange(start=0, end=half, dtype=torch.float32) / half
71
+ ).to(device=timesteps.device)
72
+ args = timesteps[:, None].float() * freqs[None]
73
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
74
+ if dim % 2:
75
+ embedding = torch.cat([embedding,
76
+ torch.zeros_like(embedding[:, :1])],
77
+ dim=-1)
78
+ return embedding
79
+
80
+
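
A quick usage sketch for `timestep_embedding` (assuming this module is importable as `models.dit.modules`): each scalar timestep becomes a `dim`-dimensional vector whose first half holds cosines and second half sines at geometrically spaced frequencies.

import torch
from models.dit.modules import timestep_embedding

t = torch.tensor([0, 10, 999])
emb = timestep_embedding(t, dim=256)
print(emb.shape)                                      # torch.Size([3, 256])
print(torch.allclose(emb[0, :128], torch.ones(128)))  # True: cos(0) = 1 at every frequency
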
81
+ class TimestepEmbedder(nn.Module):
82
+ """
83
+ Embeds scalar timesteps into vector representations.
84
+ """
85
+ def __init__(
86
+ self, hidden_size, frequency_embedding_size=256, out_size=None
87
+ ):
88
+ super().__init__()
89
+ if out_size is None:
90
+ out_size = hidden_size
91
+ self.mlp = nn.Sequential(
92
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
93
+ nn.SiLU(),
94
+ nn.Linear(hidden_size, out_size, bias=True),
95
+ )
96
+ self.frequency_embedding_size = frequency_embedding_size
97
+
98
+ def forward(self, t):
99
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(
100
+ self.mlp[0].weight.dtype
101
+ )
102
+ t_emb = self.mlp(t_freq)
103
+ return t_emb
104
+
105
+
106
+ def patchify(imgs, patch_size, input_type='2d'):
107
+ if input_type == '2d':
108
+ x = einops.rearrange(
109
+ imgs,
110
+ 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)',
111
+ p1=patch_size,
112
+ p2=patch_size
113
+ )
114
+ elif input_type == '1d':
115
+ x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size)
116
+ return x
117
+
118
+
119
+ def unpatchify(x, channels=3, input_type='2d', img_size=None):
120
+ if input_type == '2d':
121
+ patch_size = int((x.shape[2] // channels)**0.5)
122
+ # h = w = int(x.shape[1] ** .5)
123
+ h, w = img_size[0] // patch_size, img_size[1] // patch_size
124
+ assert h * w == x.shape[1] and patch_size**2 * channels == x.shape[2]
125
+ x = einops.rearrange(
126
+ x,
127
+ 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)',
128
+ h=h,
129
+ p1=patch_size,
130
+ p2=patch_size
131
+ )
132
+ elif input_type == '1d':
133
+ patch_size = int((x.shape[2] // channels))
134
+ h = x.shape[1]
135
+ assert patch_size * channels == x.shape[2]
136
+ x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size)
137
+ return x
138
+
139
+
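
A hedged round-trip sketch for the 1d `patchify`/`unpatchify` pair above (again assuming `models.dit.modules` is importable; shapes are illustrative):

import torch
from models.dit.modules import patchify, unpatchify

x = torch.randn(2, 8, 100)                        # (B, C, T), T divisible by patch_size
patches = patchify(x, patch_size=4, input_type='1d')
print(patches.shape)                              # torch.Size([2, 25, 32]) = (B, T/4, 4*C)
recon = unpatchify(patches, channels=8, input_type='1d')
print(torch.allclose(recon, x))                   # True: the rearrange is lossless
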
140
+ class PatchEmbed(nn.Module):
141
+ """
142
+ Image to Patch Embedding
143
+ """
144
+ def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'):
145
+ super().__init__()
146
+ self.patch_size = patch_size
147
+ self.input_type = input_type
148
+ if input_type == '2d':
149
+ self.proj = nn.Conv2d(
150
+ in_chans,
151
+ embed_dim,
152
+ kernel_size=patch_size,
153
+ stride=patch_size,
154
+ bias=True
155
+ )
156
+ elif input_type == '1d':
157
+ self.proj = nn.Conv1d(
158
+ in_chans,
159
+ embed_dim,
160
+ kernel_size=patch_size,
161
+ stride=patch_size,
162
+ bias=True
163
+ )
164
+
165
+ def forward(self, x):
166
+ if self.input_type == '2d':
167
+ B, C, H, W = x.shape
168
+ assert H % self.patch_size == 0 and W % self.patch_size == 0
169
+ elif self.input_type == '1d':
170
+ B, C, H = x.shape
171
+ assert H % self.patch_size == 0
172
+
173
+ x = self.proj(x).flatten(2).transpose(1, 2)
174
+ return x
175
+
176
+
177
+ class PositionalConvEmbedding(nn.Module):
178
+ """
179
+ Convolutional positional embedding used in F5-TTS.
180
+ """
181
+ def __init__(self, dim=768, kernel_size=31, groups=16):
182
+ super().__init__()
183
+ assert kernel_size % 2 != 0
184
+ self.conv1d = nn.Sequential(
185
+ nn.Conv1d(
186
+ dim, dim, kernel_size, groups=groups, padding=kernel_size // 2
187
+ ),
188
+ nn.Mish(),
189
+ nn.Conv1d(
190
+ dim, dim, kernel_size, groups=groups, padding=kernel_size // 2
191
+ ),
192
+ nn.Mish(),
193
+ )
194
+
195
+ def forward(self, x):
196
+ # B T C
197
+ x = self.conv1d(x.transpose(1, 2))
198
+ x = x.transpose(1, 2)
199
+ return x
200
+
201
+
202
+ class SinusoidalPositionalEncoding(nn.Module):
203
+ def __init__(self, dim, length):
204
+ super(SinusoidalPositionalEncoding, self).__init__()
205
+ self.length = length
206
+ self.dim = dim
207
+ self.register_buffer(
208
+ 'pe', self._generate_positional_encoding(length, dim)
209
+ )
210
+
211
+ def _generate_positional_encoding(self, length, dim):
212
+ pe = torch.zeros(length, dim)
213
+ position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
214
+ div_term = torch.exp(
215
+ torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)
216
+ )
217
+
218
+ pe[:, 0::2] = torch.sin(position * div_term)
219
+ pe[:, 1::2] = torch.cos(position * div_term)
220
+
221
+ pe = pe.unsqueeze(0)
222
+ return pe
223
+
224
+ def forward(self, x):
225
+ x = x + self.pe[:, :x.size(1)]
226
+ return x
227
+
228
+
229
+ class PE_wrapper(nn.Module):
230
+ def __init__(self, dim=768, method='abs', length=None, **kwargs):
231
+ super().__init__()
232
+ self.method = method
233
+ if method == 'abs':
234
+ # init absolute pe like UViT
235
+ self.length = length
236
+ self.abs_pe = nn.Parameter(torch.zeros(1, length, dim))
237
+ trunc_normal_(self.abs_pe, mean=0.0, std=.02, a=-.04, b=.04)
238
+ elif method == 'conv':
239
+ self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs)
240
+ elif method == 'sinu':
241
+ self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length)
242
+ elif method == 'none':
243
+ # skip pe
244
+ self.id = nn.Identity()
245
+ else:
246
+ raise NotImplementedError
247
+
248
+ def forward(self, x):
249
+ if self.method == 'abs':
250
+ _, L, _ = x.shape
251
+ assert L <= self.length
252
+ x = x + self.abs_pe[:, :L, :]
253
+ elif self.method == 'conv':
254
+ x = x + self.conv_pe(x)
255
+ elif self.method == 'sinu':
256
+ x = self.sinu_pe(x)
257
+ elif self.method == 'none':
258
+ x = self.id(x)
259
+ else:
260
+ raise NotImplementedError
261
+ return x
262
+
263
+
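
A short sketch of how `PE_wrapper` is meant to be used with the learnable absolute positional embedding (import path taken from this file; the sizes are illustrative):

import torch
from models.dit.modules import PE_wrapper

pe = PE_wrapper(dim=768, method='abs', length=1000)   # truncated-normal initialized abs PE
h = torch.randn(4, 250, 768)                          # (B, L, D) with L <= length
out = pe(h)                                           # adds abs_pe[:, :250, :]
print(out.shape)                                      # torch.Size([4, 250, 768])
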
264
+ class RMSNorm(torch.nn.Module):
265
+ def __init__(self, dim: int, eps: float = 1e-6):
266
+ """
267
+ Initialize the RMSNorm normalization layer.
268
+
269
+ Args:
270
+ dim (int): The dimension of the input tensor.
271
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
272
+
273
+ Attributes:
274
+ eps (float): A small value added to the denominator for numerical stability.
275
+ weight (nn.Parameter): Learnable scaling parameter.
276
+
277
+ """
278
+ super().__init__()
279
+ self.eps = eps
280
+ self.weight = nn.Parameter(torch.ones(dim))
281
+
282
+ def _norm(self, x):
283
+ """
284
+ Apply the RMSNorm normalization to the input tensor.
285
+
286
+ Args:
287
+ x (torch.Tensor): The input tensor.
288
+
289
+ Returns:
290
+ torch.Tensor: The normalized tensor.
291
+
292
+ """
293
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
294
+
295
+ def forward(self, x):
296
+ """
297
+ Forward pass through the RMSNorm layer.
298
+
299
+ Args:
300
+ x (torch.Tensor): The input tensor.
301
+
302
+ Returns:
303
+ torch.Tensor: The output tensor after applying RMSNorm.
304
+
305
+ """
306
+ output = self._norm(x.float()).type_as(x)
307
+ return output * self.weight
308
+
309
+
310
+ class GELU(nn.Module):
311
+ def __init__(
312
+ self,
313
+ dim_in: int,
314
+ dim_out: int,
315
+ approximate: str = "none",
316
+ bias: bool = True
317
+ ):
318
+ super().__init__()
319
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
320
+ self.approximate = approximate
321
+
322
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
323
+ if gate.device.type != "mps":
324
+ return F.gelu(gate, approximate=self.approximate)
325
+ # mps: gelu is not implemented for float16
326
+ return F.gelu(
327
+ gate.to(dtype=torch.float32), approximate=self.approximate
328
+ ).to(dtype=gate.dtype)
329
+
330
+ def forward(self, hidden_states):
331
+ hidden_states = self.proj(hidden_states)
332
+ hidden_states = self.gelu(hidden_states)
333
+ return hidden_states
334
+
335
+
336
+ class GEGLU(nn.Module):
337
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
338
+ super().__init__()
339
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
340
+
341
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
342
+ if gate.device.type != "mps":
343
+ return F.gelu(gate)
344
+ # mps: gelu is not implemented for float16
345
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
346
+
347
+ def forward(self, hidden_states):
348
+ hidden_states = self.proj(hidden_states)
349
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
350
+ return hidden_states * self.gelu(gate)
351
+
352
+
353
+ class ApproximateGELU(nn.Module):
354
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
355
+ super().__init__()
356
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
357
+
358
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
359
+ x = self.proj(x)
360
+ return x * torch.sigmoid(1.702 * x)
361
+
362
+
363
+ # disable in checkpoint mode
364
+ # @torch.jit.script
365
+ def snake_beta(x, alpha, beta):
366
+ return x + beta * torch.sin(x * alpha).pow(2)
367
+
368
+
369
+ class Snake(nn.Module):
370
+ def __init__(self, dim_in, dim_out, bias, alpha_trainable=True):
371
+ super().__init__()
372
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
373
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
374
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
375
+ self.alpha.requires_grad = alpha_trainable
376
+ self.beta.requires_grad = alpha_trainable
377
+
378
+ def forward(self, x):
379
+ x = self.proj(x)
380
+ x = snake_beta(x, self.alpha, self.beta)
381
+ return x
382
+
383
+
384
+ class GESnake(nn.Module):
385
+ def __init__(self, dim_in, dim_out, bias, alpha_trainable=True):
386
+ super().__init__()
387
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
388
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
389
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
390
+ self.alpha.requires_grad = alpha_trainable
391
+ self.beta.requires_grad = alpha_trainable
392
+
393
+ def forward(self, x):
394
+ x = self.proj(x)
395
+ x, gate = x.chunk(2, dim=-1)
396
+ return x * snake_beta(gate, self.alpha, self.beta)
397
+
398
+
399
+ class FeedForward(nn.Module):
400
+ def __init__(
401
+ self,
402
+ dim,
403
+ dim_out=None,
404
+ mult=4,
405
+ dropout=0.0,
406
+ activation_fn="geglu",
407
+ final_dropout=False,
408
+ inner_dim=None,
409
+ bias=True,
410
+ ):
411
+ super().__init__()
412
+ if inner_dim is None:
413
+ inner_dim = int(dim * mult)
414
+ dim_out = dim_out if dim_out is not None else dim
415
+
416
+ if activation_fn == "gelu":
417
+ act_fn = GELU(dim, inner_dim, bias=bias)
418
+ elif activation_fn == "gelu-approximate":
419
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
420
+ elif activation_fn == "geglu":
421
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
422
+ elif activation_fn == "geglu-approximate":
423
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
424
+ elif activation_fn == "snake":
425
+ act_fn = Snake(dim, inner_dim, bias=bias)
426
+ elif activation_fn == "gesnake":
427
+ act_fn = GESnake(dim, inner_dim, bias=bias)
428
+ else:
429
+ raise NotImplementedError
430
+
431
+ self.net = nn.ModuleList([])
432
+ # project in
433
+ self.net.append(act_fn)
434
+ # project dropout
435
+ self.net.append(nn.Dropout(dropout))
436
+ # project out
437
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
438
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
439
+ if final_dropout:
440
+ self.net.append(nn.Dropout(dropout))
441
+
442
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
443
+ for module in self.net:
444
+ hidden_states = module(hidden_states)
445
+ return hidden_states
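
To close the module, a usage sketch (not part of the commit) of constructing the `FeedForward` block above with its default GEGLU activation:

import torch
from models.dit.modules import FeedForward

ff = FeedForward(dim=512, mult=4, activation_fn="geglu", dropout=0.1)
h = torch.randn(2, 100, 512)
print(ff(h).shape)            # torch.Size([2, 100, 512]); GEGLU expands to 2048 internally
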
models/dit/rotary.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+ "this rope is faster than llama rope with jit script"
3
+
4
+
5
+ def rotate_half(x):
6
+ x1, x2 = x.chunk(2, dim=-1)
7
+ return torch.cat((-x2, x1), dim=-1)
8
+
9
+
10
+ # disable in checkpoint mode
11
+ # @torch.jit.script
12
+ def apply_rotary_pos_emb(x, cos, sin):
13
+ # NOTE: This could probably be moved to Triton
14
+ # Handle a possible sequence length mismatch in between q and k
15
+ cos = cos[:, :, :x.shape[-2], :]
16
+ sin = sin[:, :, :x.shape[-2], :]
17
+ return (x*cos) + (rotate_half(x) * sin)
18
+
19
+
20
+ class RotaryEmbedding(torch.nn.Module):
21
+ """
22
+ The rotary position embeddings from RoFormer_ (Su et. al).
23
+ A crucial insight from the method is that the query and keys are
24
+ transformed by rotation matrices which depend on the relative positions.
25
+
26
+ Other implementations are available in the Rotary Transformer repo_ and in
27
+ GPT-NeoX_; GPT-NeoX was an inspiration.
28
+
29
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
30
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
31
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
32
+
33
+
34
+ .. warning: Please note that this embedding is not registered on purpose, as it is transformative
35
+ (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
36
+ """
37
+ def __init__(self, dim: int):
38
+ super().__init__()
39
+ # Generate and save the inverse frequency buffer (non trainable)
40
+ inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
41
+ self.register_buffer("inv_freq", inv_freq)
42
+ self._seq_len_cached = None
43
+ self._cos_cached = None
44
+ self._sin_cached = None
45
+
46
+ def _update_cos_sin_tables(self, x, seq_dimension=-2):
47
+ # expect input: B, H, L, D
48
+ seq_len = x.shape[seq_dimension]
49
+
50
+ # Reset the tables if the sequence length has changed,
51
+ # or if we're on a new device (possibly due to tracing for instance)
52
+ # also make sure dtype wont change
53
+ if (
54
+ seq_len != self._seq_len_cached or
55
+ self._cos_cached.device != x.device or
56
+ self._cos_cached.dtype != x.dtype
57
+ ):
58
+ self._seq_len_cached = seq_len
59
+ t = torch.arange(
60
+ x.shape[seq_dimension], device=x.device, dtype=torch.float32
61
+ )
62
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
63
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
64
+
65
+ self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
66
+ self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
67
+
68
+ return self._cos_cached, self._sin_cached
69
+
70
+ def forward(self, q, k):
71
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
72
+ q.float(), seq_dimension=-2
73
+ )
74
+ if k is not None:
75
+ return (
76
+ apply_rotary_pos_emb(
77
+ q.float(), self._cos_cached, self._sin_cached
78
+ ).type_as(q),
79
+ apply_rotary_pos_emb(
80
+ k.float(), self._cos_cached, self._sin_cached
81
+ ).type_as(k),
82
+ )
83
+ else:
84
+ return (
85
+ apply_rotary_pos_emb(
86
+ q.float(), self._cos_cached, self._sin_cached
87
+ ).type_as(q), None
88
+ )
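
A usage sketch for `RotaryEmbedding` on attention tensors shaped `(B, H, L, D_head)` (import path from this file); the cos/sin tables are cached and rebuilt only when the sequence length, device, or dtype changes.

import torch
from models.dit.rotary import RotaryEmbedding

rope = RotaryEmbedding(dim=64)                 # per-head dimension
q = torch.randn(2, 8, 100, 64)
k = torch.randn(2, 8, 100, 64)
q_rot, k_rot = rope(q, k)
print(q_rot.shape, k_rot.shape)                # both torch.Size([2, 8, 100, 64])
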
models/dit/span_mask.py ADDED
@@ -0,0 +1,149 @@
1
+ import numpy as np
2
+ import torch
3
+ from typing import Optional, Tuple
4
+
5
+
6
+ def compute_mask_indices(
7
+ shape: Tuple[int, int],
8
+ padding_mask: Optional[torch.Tensor],
9
+ mask_prob: float,
10
+ mask_length: int,
11
+ mask_type: str = "static",
12
+ mask_other: float = 0.0,
13
+ min_masks: int = 0,
14
+ no_overlap: bool = False,
15
+ min_space: int = 0,
16
+ ) -> np.ndarray:
17
+ """
18
+ Computes random mask spans for a given shape
19
+
20
+ Args:
21
+ shape: the shape for which to compute masks.
22
+ should be of size 2 where first element is batch size and 2nd is timesteps
23
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
24
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
25
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
26
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
27
+ mask_type: how to compute mask lengths
28
+ static = fixed size
29
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
30
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
31
+ poisson = sample from Poisson distribution with lambda = mask_length
32
+ min_masks: minimum number of masked spans
33
+ no_overlap: if true, will use a recursive algorithm that prevents spans from overlapping
34
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
35
+ """
36
+
37
+ bsz, all_sz = shape
38
+ mask = np.full((bsz, all_sz), False)
39
+
40
+ # Convert mask_prob to a NumPy array
41
+ mask_prob = np.array(mask_prob)
42
+
43
+ # Calculate all_num_mask for each element in the batch
44
+ all_num_mask = np.floor(
45
+ mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)
46
+ ).astype(int)
47
+
48
+ # Apply the max operation with min_masks for each element
49
+ all_num_mask = np.maximum(min_masks, all_num_mask)
50
+
51
+ mask_idcs = []
52
+ for i in range(bsz):
53
+ if padding_mask is not None:
54
+ sz = all_sz - padding_mask[i].long().sum().item()
55
+ num_mask = int(
56
+ # add a random number for probabilistic rounding
57
+ mask_prob * sz / float(mask_length) + np.random.rand()
58
+ )
59
+ num_mask = max(min_masks, num_mask)
60
+ else:
61
+ sz = all_sz
62
+ num_mask = all_num_mask[i]
63
+
64
+ if mask_type == "static":
65
+ lengths = np.full(num_mask, mask_length)
66
+ elif mask_type == "uniform":
67
+ lengths = np.random.randint(
68
+ mask_other, mask_length*2 + 1, size=num_mask
69
+ )
70
+ elif mask_type == "normal":
71
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
72
+ lengths = [max(1, int(round(x))) for x in lengths]
73
+ elif mask_type == "poisson":
74
+ lengths = np.random.poisson(mask_length, size=num_mask)
75
+ lengths = [int(round(x)) for x in lengths]
76
+ else:
77
+ raise Exception("unknown mask selection " + mask_type)
78
+
79
+ if sum(lengths) == 0:
80
+ lengths[0] = min(mask_length, sz - 1)
81
+
82
+ if no_overlap:
83
+ mask_idc = []
84
+
85
+ def arrange(s, e, length, keep_length):
86
+ span_start = np.random.randint(s, e - length)
87
+ mask_idc.extend(span_start + i for i in range(length))
88
+
89
+ new_parts = []
90
+ if span_start - s - min_space >= keep_length:
91
+ new_parts.append((s, span_start - min_space + 1))
92
+ if e - span_start - keep_length - min_space > keep_length:
93
+ new_parts.append((span_start + length + min_space, e))
94
+ return new_parts
95
+
96
+ parts = [(0, sz)]
97
+ min_length = min(lengths)
98
+ for length in sorted(lengths, reverse=True):
99
+ lens = np.fromiter(
100
+ (
101
+ e - s if e - s >= length + min_space else 0
102
+ for s, e in parts
103
+ ),
104
+ int,
105
+ )
106
+ l_sum = np.sum(lens)
107
+ if l_sum == 0:
108
+ break
109
+ probs = lens / np.sum(lens)
110
+ c = np.random.choice(len(parts), p=probs)
111
+ s, e = parts.pop(c)
112
+ parts.extend(arrange(s, e, length, min_length))
113
+ mask_idc = np.asarray(mask_idc)
114
+ else:
115
+ min_len = min(lengths)
116
+ if sz - min_len <= num_mask:
117
+ min_len = sz - num_mask - 1
118
+
119
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
120
+
121
+ mask_idc = np.asarray([
122
+ mask_idc[j] + offset for j in range(len(mask_idc))
123
+ for offset in range(lengths[j])
124
+ ])
125
+
126
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
127
+ # min_len = min([len(m) for m in mask_idcs])
128
+ for i, mask_idc in enumerate(mask_idcs):
129
+ # if len(mask_idc) > min_len:
130
+ # mask_idc = np.random.choice(mask_idc, min_len, replace=False)
131
+ mask[i, mask_idc] = True
132
+
133
+ return torch.tensor(mask)
134
+
135
+
136
+ if __name__ == '__main__':
137
+ mask = compute_mask_indices(
138
+ shape=[4, 500],
139
+ padding_mask=None,
140
+ mask_prob=[0.65, 0.5, 0.65, 0.65],
141
+ mask_length=10,
142
+ mask_type="static",
143
+ mask_other=0.0,
144
+ min_masks=1,
145
+ no_overlap=False,
146
+ min_space=0,
147
+ )
148
+ print(mask)
149
+ print(mask.sum(dim=1))
models/flow_matching.py ADDED
@@ -0,0 +1,1082 @@
1
+ from typing import Any, Optional, Union, List, Sequence
2
+
3
+ import inspect
4
+ import random
5
+
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import copy
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from diffusers.utils.torch_utils import randn_tensor
14
+ from diffusers import FlowMatchEulerDiscreteScheduler
15
+ from diffusers.training_utils import compute_density_for_timestep_sampling
16
+
17
+ from models.autoencoder.autoencoder_base import AutoEncoderBase
18
+ from models.content_encoder.content_encoder import ContentEncoder
19
+ from models.content_adapter import ContentAdapterBase
20
+ from models.common import LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase
21
+ from utils.torch_utilities import (
22
+ create_alignment_path, create_mask_from_length, loss_with_mask,
23
+ trim_or_pad_length
24
+ )
25
+ from constants import SAME_LENGTH_TASKS
26
+
27
+
28
+ class FlowMatchingMixin:
29
+ def __init__(
30
+ self,
31
+ cfg_drop_ratio: float = 0.2,
32
+ sample_strategy: str = 'normal',
33
+ num_train_steps: int = 1000
34
+ ) -> None:
35
+ r"""
36
+ Args:
37
+ cfg_drop_ratio (float): Probability of dropping the condition for classifier-free guidance during training.
38
+ sample_strategy (str): Sampling strategy for timesteps during training.
39
+ num_train_steps (int): Number of training steps for the noise scheduler.
40
+ """
41
+ self.sample_strategy = sample_strategy
42
+ self.infer_noise_scheduler = FlowMatchEulerDiscreteScheduler(
43
+ num_train_timesteps=num_train_steps
44
+ )
45
+ self.train_noise_scheduler = copy.deepcopy(self.infer_noise_scheduler)
46
+
47
+ self.classifier_free_guidance = cfg_drop_ratio > 0.0
48
+ self.cfg_drop_ratio = cfg_drop_ratio
49
+
50
+ def get_input_target_and_timesteps(
51
+ self,
52
+ latent: torch.Tensor,
53
+ training: bool,
54
+ ):
55
+ batch_size = latent.shape[0]
56
+ noise = torch.randn_like(latent)
57
+
58
+ if training:
59
+ if self.sample_strategy == 'normal':
60
+ u = compute_density_for_timestep_sampling(
61
+ weighting_scheme="logit_normal",
62
+ batch_size=batch_size,
63
+ logit_mean=0,
64
+ logit_std=1,
65
+ mode_scale=None,
66
+ )
67
+ elif self.sample_strategy == 'uniform':
68
+ u = torch.rand(batch_size, )
69
+ else:
70
+ raise NotImplementedError(
71
+ f"{self.sample_strategy} samlping for timesteps is not supported now"
72
+ )
73
+
74
+ indices = (
75
+ u * self.train_noise_scheduler.config.num_train_timesteps
76
+ ).long()
77
+ else:
78
+ indices = (
79
+ self.train_noise_scheduler.config.num_train_timesteps // 2
80
+ ) * torch.ones((batch_size, )).long()
81
+
82
+ # train_noise_scheduler.timesteps: a list from 1 to num_train_timesteps with an interval of 1
83
+ timesteps = self.train_noise_scheduler.timesteps[indices].to(
84
+ device=latent.device
85
+ )
86
+ sigmas = self.get_sigmas(
87
+ timesteps, n_dim=latent.ndim, dtype=latent.dtype
88
+ )
89
+
90
+ noisy_latent = (1.0 - sigmas) * latent + sigmas * noise
91
+
92
+ target = noise - latent
93
+
94
+ return noisy_latent, target, timesteps
95
+
96
+ def get_sigmas(self, timesteps, n_dim=3, dtype=torch.float32):
97
+ device = timesteps.device
98
+
99
+ # a list from 1 declining to 1/num_train_steps
100
+ sigmas = self.train_noise_scheduler.sigmas.to(
101
+ device=device, dtype=dtype
102
+ )
103
+
104
+ schedule_timesteps = self.train_noise_scheduler.timesteps.to(device)
105
+ timesteps = timesteps.to(device)
106
+ step_indices = [(schedule_timesteps == t).nonzero().item()
107
+ for t in timesteps]
108
+
109
+ sigma = sigmas[step_indices].flatten()
110
+ while len(sigma.shape) < n_dim:
111
+ sigma = sigma.unsqueeze(-1)
112
+ return sigma
113
+
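
A standalone sketch of the rectified-flow construction used by `get_input_target_and_timesteps` above: with `x_t = (1 - sigma) * latent + sigma * noise`, the velocity `d x_t / d sigma` is exactly `noise - latent`, which is what the backbone is trained to predict. All tensors below are random placeholders.

import torch

latent = torch.randn(4, 8, 100)
noise = torch.randn_like(latent)
sigma = torch.rand(4, 1, 1)                     # one sigma per batch element

x_t = (1.0 - sigma) * latent + sigma * noise
target = noise - latent

# finite-difference check of the velocity along sigma
eps = 1e-2
x_t_eps = (1.0 - (sigma + eps)) * latent + (sigma + eps) * noise
print(torch.allclose((x_t_eps - x_t) / eps, target, atol=1e-3))   # True
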
114
+ def retrieve_timesteps(
115
+ self,
116
+ num_inference_steps: Optional[int] = None,
117
+ device: Optional[Union[str, torch.device]] = None,
118
+ timesteps: Optional[List[int]] = None,
119
+ sigmas: Optional[List[float]] = None,
120
+ **kwargs,
121
+ ):
122
+ # used in inference: build the timestep schedule from num_inference_steps or custom timesteps/sigmas
123
+ scheduler = self.infer_noise_scheduler
124
+
125
+ if timesteps is not None and sigmas is not None:
126
+ raise ValueError(
127
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
128
+ )
129
+ if timesteps is not None:
130
+ accepts_timesteps = "timesteps" in set(
131
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
132
+ )
133
+ if not accepts_timesteps:
134
+ raise ValueError(
135
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
136
+ f" timestep schedules. Please check whether you are using the correct scheduler."
137
+ )
138
+ scheduler.set_timesteps(
139
+ timesteps=timesteps, device=device, **kwargs
140
+ )
141
+ timesteps = scheduler.timesteps
142
+ num_inference_steps = len(timesteps)
143
+ elif sigmas is not None:
144
+ accept_sigmas = "sigmas" in set(
145
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
146
+ )
147
+ if not accept_sigmas:
148
+ raise ValueError(
149
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
150
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
151
+ )
152
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
153
+ timesteps = scheduler.timesteps
154
+ num_inference_steps = len(timesteps)
155
+ else:
156
+ scheduler.set_timesteps(
157
+ num_inference_steps, device=device, **kwargs
158
+ )
159
+ timesteps = scheduler.timesteps
160
+ return timesteps, num_inference_steps
161
+
162
+
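
For reference, a sketch of the sway-sampling sigma schedule that the inference methods below construct and hand to `retrieve_timesteps` (values are illustrative; the -1.0 coefficient matches the default argument used later):

import torch

num_steps = 20
sway_sampling_coef = -1.0

t = torch.linspace(0, 1, num_steps + 1)
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
sigmas = 1 - t                                 # runs from 1.0 down to ~0.0
print(sigmas[0].item(), sigmas[-1].item())     # 1.0 and ~0.0
# the schedule is denser near sigma = 1, i.e. more steps are spent early in denoising
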
163
+ class ContentEncoderAdapterMixin:
164
+ def __init__(
165
+ self,
166
+ content_encoder: ContentEncoder,
167
+ content_adapter: ContentAdapterBase | None = None
168
+ ):
169
+ self.content_encoder = content_encoder
170
+ self.content_adapter = content_adapter
171
+
172
+ def encode_content(
173
+ self,
174
+ content: list[Any],
175
+ task: list[str],
176
+ device: str | torch.device,
177
+ instruction: torch.Tensor | None = None,
178
+ instruction_lengths: torch.Tensor | None = None
179
+ ):
180
+ content_output: dict[
181
+ str, torch.Tensor] = self.content_encoder.encode_content(
182
+ content, task, device=device
183
+ )
184
+ content, content_mask = content_output["content"], content_output[
185
+ "content_mask"]
186
+
187
+ if instruction is not None:
188
+ instruction_mask = create_mask_from_length(instruction_lengths)
189
+ (
190
+ content,
191
+ content_mask,
192
+ global_duration_pred,
193
+ local_duration_pred,
194
+ ) = self.content_adapter(
195
+ content, content_mask, instruction, instruction_mask
196
+ )
197
+
198
+ return_dict = {
199
+ "content": content,
200
+ "content_mask": content_mask,
201
+ "length_aligned_content": content_output["length_aligned_content"],
202
+ }
203
+ if instruction is not None:
204
+ return_dict["global_duration_pred"] = global_duration_pred
205
+ return_dict["local_duration_pred"] = local_duration_pred
206
+
207
+ return return_dict
208
+
209
+
210
+ class SingleTaskCrossAttentionAudioFlowMatching(
211
+ LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
212
+ FlowMatchingMixin, ContentEncoderAdapterMixin
213
+ ):
214
+ def __init__(
215
+ self,
216
+ autoencoder: nn.Module,
217
+ content_encoder: ContentEncoder,
218
+ backbone: nn.Module,
219
+ cfg_drop_ratio: float = 0.2,
220
+ sample_strategy: str = 'normal',
221
+ num_train_steps: int = 1000,
222
+ ):
223
+ nn.Module.__init__(self)
224
+ FlowMatchingMixin.__init__(
225
+ self, cfg_drop_ratio, sample_strategy, num_train_steps
226
+ )
227
+ ContentEncoderAdapterMixin.__init__(
228
+ self, content_encoder=content_encoder
229
+ )
230
+
231
+ self.autoencoder = autoencoder
232
+ for param in self.autoencoder.parameters():
233
+ param.requires_grad = False
234
+
235
+ if hasattr(
236
+ self.content_encoder, "audio_encoder"
237
+ ) and self.content_encoder.audio_encoder is not None:
238
+ self.content_encoder.audio_encoder.model = self.autoencoder
239
+
240
+ self.backbone = backbone
241
+ self.dummy_param = nn.Parameter(torch.empty(0))
242
+
243
+ def forward(
244
+ self, content: list[Any], condition: list[Any], task: list[str],
245
+ waveform: torch.Tensor, waveform_lengths: torch.Tensor, **kwargs
246
+ ):
247
+ device = self.dummy_param.device
248
+
249
+ self.autoencoder.eval()
250
+ with torch.no_grad():
251
+ latent, latent_mask = self.autoencoder.encode(
252
+ waveform.unsqueeze(1), waveform_lengths
253
+ )
254
+
255
+ content_dict = self.encode_content(content, task, device)
256
+ content, content_mask = content_dict["content"], content_dict[
257
+ "content_mask"]
258
+
259
+ if self.training and self.classifier_free_guidance:
260
+ mask_indices = [
261
+ k for k in range(len(waveform))
262
+ if random.random() < self.cfg_drop_ratio
263
+ ]
264
+ if len(mask_indices) > 0:
265
+ content[mask_indices] = 0
266
+
267
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
268
+ latent, training=self.training
269
+ )
270
+
271
+ pred: torch.Tensor = self.backbone(
272
+ x=noisy_latent,
273
+ timesteps=timesteps,
274
+ context=content,
275
+ x_mask=latent_mask,
276
+ context_mask=content_mask
277
+ )
278
+
279
+ loss = F.mse_loss(pred.float(), target.float(), reduction="none")
280
+ loss = loss_with_mask(loss, latent_mask)
281
+
282
+ return loss
283
+
284
+ def iterative_denoise(
285
+ self, latent: torch.Tensor, timesteps: list[int], num_steps: int,
286
+ verbose: bool, cfg: bool, cfg_scale: float, backbone_input: dict
287
+ ):
288
+ progress_bar = tqdm(range(num_steps), disable=not verbose)
289
+
290
+ for i, timestep in enumerate(timesteps):
291
+ # expand the latent if we are doing classifier free guidance
292
+ if cfg:
293
+ latent_input = torch.cat([latent, latent])
294
+ else:
295
+ latent_input = latent
296
+
297
+ noise_pred: torch.Tensor = self.backbone(
298
+ x=latent_input, timesteps=timestep, **backbone_input
299
+ )
300
+
301
+ # perform guidance
302
+ if cfg:
303
+ noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
304
+ noise_pred = noise_pred_uncond + cfg_scale * (
305
+ noise_pred_content - noise_pred_uncond
306
+ )
307
+
308
+ latent = self.infer_noise_scheduler.step(
309
+ noise_pred, timestep, latent
310
+ ).prev_sample
311
+
312
+ progress_bar.update(1)
313
+
314
+ progress_bar.close()
315
+
316
+ return latent
317
+
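
A standalone sketch of the classifier-free guidance recombination performed inside `iterative_denoise`: the batch is doubled (unconditional half first, conditional half second), and the two halves are mixed with the guidance scale.

import torch

cfg_scale = 3.0
noise_pred = torch.randn(8, 16, 100)           # 2 * batch_size of 4 (illustrative)

noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
guided = noise_pred_uncond + cfg_scale * (noise_pred_content - noise_pred_uncond)
print(guided.shape)                            # torch.Size([4, 16, 100])
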
318
+ @torch.no_grad()
319
+ def inference(
320
+ self,
321
+ content: list[Any],
322
+ condition: list[Any],
323
+ task: list[str],
324
+ latent_shape: Sequence[int],
325
+ num_steps: int = 50,
326
+ sway_sampling_coef: float | None = -1.0,
327
+ guidance_scale: float = 3.0,
328
+ num_samples_per_content: int = 1,
329
+ disable_progress: bool = True,
330
+ **kwargs
331
+ ):
332
+ device = self.dummy_param.device
333
+ classifier_free_guidance = guidance_scale > 1.0
334
+ batch_size = len(content) * num_samples_per_content
335
+
336
+ if classifier_free_guidance:
337
+ content, content_mask = self.encode_content_classifier_free(
338
+ content, task, device, num_samples_per_content
339
+ )
340
+ else:
341
+ content_output: dict[
342
+ str, torch.Tensor] = self.content_encoder.encode_content(
343
+ content, task
344
+ )
345
+ content, content_mask = content_output["content"], content_output[
346
+ "content_mask"]
347
+ content = content.repeat_interleave(num_samples_per_content, 0)
348
+ content_mask = content_mask.repeat_interleave(
349
+ num_samples_per_content, 0
350
+ )
351
+
352
+ latent = self.prepare_latent(
353
+ batch_size, latent_shape, content.dtype, device
354
+ )
355
+
356
+ if not sway_sampling_coef:
357
+ sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
358
+ else:
359
+ t = torch.linspace(0, 1, num_steps + 1)
360
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
361
+ sigmas = 1 - t
362
+ timesteps, num_steps = self.retrieve_timesteps(
363
+ num_steps, device, timesteps=None, sigmas=sigmas
364
+ )
365
+
366
+ latent = self.iterative_denoise(
367
+ latent=latent,
368
+ timesteps=timesteps,
369
+ num_steps=num_steps,
370
+ verbose=not disable_progress,
371
+ cfg=classifier_free_guidance,
372
+ cfg_scale=guidance_scale,
373
+ backbone_input={
374
+ "context": content,
375
+ "context_mask": content_mask,
376
+ },
377
+ )
378
+
379
+ waveform = self.autoencoder.decode(latent)
380
+
381
+ return waveform
382
+
383
+ def prepare_latent(
384
+ self, batch_size: int, latent_shape: Sequence[int], dtype: torch.dtype,
385
+ device: str
386
+ ):
387
+ shape = (batch_size, *latent_shape)
388
+ latent = randn_tensor(
389
+ shape, generator=None, device=device, dtype=dtype
390
+ )
391
+ return latent
392
+
393
+ def encode_content_classifier_free(
394
+ self,
395
+ content: list[Any],
396
+ task: list[str],
397
+ device,
398
+ num_samples_per_content: int = 1
399
+ ):
400
+ content_dict = self.content_encoder.encode_content(
401
+ content, task, device=device
402
+ )
403
+ content, content_mask = content_dict["content"], content_dict[
404
+ "content_mask"]
405
+
406
+ content = content.repeat_interleave(num_samples_per_content, 0)
407
+ content_mask = content_mask.repeat_interleave(
408
+ num_samples_per_content, 0
409
+ )
410
+
411
+ # get unconditional embeddings for classifier free guidance
412
+ uncond_content = torch.zeros_like(content)
413
+ uncond_content_mask = content_mask.detach().clone()
414
+
415
+ # NOTE: content and content_mask were already expanded by repeat_interleave
416
+ # above, so uncond_content and uncond_content_mask already have the expected
417
+ # batch size and must not be repeated again here
421
+
422
+ # For classifier-free guidance we would need two forward passes;
423
+ # here the unconditional and conditional embeddings are concatenated into a single batch so one pass suffices
424
+ content = torch.cat([uncond_content, content])
425
+ content_mask = torch.cat([uncond_content_mask, content_mask])
426
+
427
+ return content, content_mask
428
+
429
+
430
+ class DurationAdapterMixin:
431
+ def __init__(
432
+ self,
433
+ latent_token_rate: int,
434
+ offset: float = 1.0,
435
+ frame_resolution: float | None = None
436
+ ):
437
+ self.latent_token_rate = latent_token_rate
438
+ self.offset = offset
439
+ self.frame_resolution = frame_resolution
440
+
441
+ def get_global_duration_loss(
442
+ self,
443
+ pred: torch.Tensor,
444
+ latent_mask: torch.Tensor,
445
+ reduce: bool = True,
446
+ ):
447
+ target = torch.log(
448
+ latent_mask.sum(1) / self.latent_token_rate + self.offset
449
+ )
450
+ loss = F.mse_loss(target, pred, reduction="mean" if reduce else "none")
451
+ return loss
452
+
453
+ def get_local_duration_loss(
454
+ self, ground_truth: torch.Tensor, pred: torch.Tensor,
455
+ mask: torch.Tensor, is_time_aligned: Sequence[bool], reduce: bool
456
+ ):
457
+ n_frames = torch.round(ground_truth / self.frame_resolution)
458
+ target = torch.log(n_frames + self.offset)
459
+ loss = loss_with_mask(
460
+ (target - pred)**2,
461
+ mask,
462
+ reduce=False,
463
+ )
464
+ loss *= is_time_aligned
465
+ if reduce:
466
+ if is_time_aligned.sum().item() == 0:
467
+ loss *= 0.0
468
+ loss = loss.mean()
469
+ else:
470
+ loss = loss.sum() / is_time_aligned.sum()
471
+
472
+ return loss
473
+
474
+ def prepare_local_duration(self, pred: torch.Tensor, mask: torch.Tensor):
475
+ pred = torch.exp(pred) * mask
476
+ pred = torch.ceil(pred) - self.offset
477
+ pred *= self.frame_resolution
478
+ return pred
479
+
480
+ def prepare_global_duration(
481
+ self,
482
+ global_pred: torch.Tensor,
483
+ local_pred: torch.Tensor,
484
+ is_time_aligned: Sequence[bool],
485
+ use_local: bool = True,
486
+ ):
487
+ """
488
+ global_pred: predicted global duration, in log scale with the offset applied
489
+ local_pred: predicted per-token local durations (in seconds)
490
+ """
491
+ global_pred = torch.exp(global_pred) - self.offset
492
+ result = global_pred
493
+ # avoid error accumulation for each frame
494
+ if use_local:
495
+ pred_from_local = torch.round(local_pred * self.latent_token_rate)
496
+ pred_from_local = pred_from_local.sum(1) / self.latent_token_rate
497
+ result[is_time_aligned] = pred_from_local[is_time_aligned]
498
+
499
+ return result
500
+
501
+ def expand_by_duration(
502
+ self,
503
+ x: torch.Tensor,
504
+ content_mask: torch.Tensor,
505
+ local_duration: torch.Tensor,
506
+ global_duration: torch.Tensor | None = None,
507
+ ):
508
+ n_latents = torch.round(local_duration * self.latent_token_rate)
509
+ if global_duration is not None:
510
+ latent_length = torch.round(
511
+ global_duration * self.latent_token_rate
512
+ )
513
+ else:
514
+ latent_length = n_latents.sum(1)
515
+ latent_mask = create_mask_from_length(latent_length).to(
516
+ content_mask.device
517
+ )
518
+ attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
519
+ align_path = create_alignment_path(n_latents, attn_mask)
520
+ expanded_x = torch.matmul(align_path.transpose(1, 2).to(x.dtype), x)
521
+ return expanded_x, latent_mask
522
+
523
+
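
A standalone sketch of the log/offset duration transform that `get_global_duration_loss` and `prepare_global_duration` above rely on: targets are `log(duration + offset)` and predictions are inverted with `exp(pred) - offset`, so the two mappings cancel exactly. The token rate is an illustrative value.

import torch

offset = 1.0
latent_token_rate = 25                         # latent frames per second (illustrative)

latent_mask = torch.ones(2, 250)               # 250 latent frames ~ 10 s of audio
duration = latent_mask.sum(1) / latent_token_rate
target = torch.log(duration + offset)          # regression target for the duration head

recovered = torch.exp(target) - offset         # the inversion used at inference time
print(duration.tolist(), recovered.tolist())   # both ~[10.0, 10.0]
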
524
+ class CrossAttentionAudioFlowMatching(
525
+ SingleTaskCrossAttentionAudioFlowMatching, DurationAdapterMixin
526
+ ):
527
+ def __init__(
528
+ self,
529
+ autoencoder: AutoEncoderBase,
530
+ content_encoder: ContentEncoder,
531
+ content_adapter: ContentAdapterBase,
532
+ backbone: nn.Module,
533
+ content_dim: int,
534
+ frame_resolution: float,
535
+ duration_offset: float = 1.0,
536
+ cfg_drop_ratio: float = 0.2,
537
+ sample_strategy: str = 'normal',
538
+ num_train_steps: int = 1000
539
+ ):
540
+ super().__init__(
541
+ autoencoder=autoencoder,
542
+ content_encoder=content_encoder,
543
+ backbone=backbone,
544
+ cfg_drop_ratio=cfg_drop_ratio,
545
+ sample_strategy=sample_strategy,
546
+ num_train_steps=num_train_steps,
547
+ )
548
+ ContentEncoderAdapterMixin.__init__(
549
+ self,
550
+ content_encoder=content_encoder,
551
+ content_adapter=content_adapter
552
+ )
553
+ DurationAdapterMixin.__init__(
554
+ self,
555
+ latent_token_rate=autoencoder.latent_token_rate,
556
+ offset=duration_offset
557
+ )
558
+
559
+ def encode_content_with_instruction(
560
+ self, content: list[Any], task: list[str], device,
561
+ instruction: torch.Tensor, instruction_lengths: torch.Tensor
562
+ ):
563
+ content_dict = self.encode_content(
564
+ content, task, device, instruction, instruction_lengths
565
+ )
566
+ return (
567
+ content_dict["content"],
568
+ content_dict["content_mask"],
569
+ content_dict["global_duration_pred"],
570
+ content_dict["local_duration_pred"],
571
+ content_dict["length_aligned_content"],
572
+ )
573
+
574
+ def forward(
575
+ self,
576
+ content: list[Any],
577
+ task: list[str],
578
+ waveform: torch.Tensor,
579
+ waveform_lengths: torch.Tensor,
580
+ instruction: torch.Tensor,
581
+ instruction_lengths: torch.Tensor,
582
+ loss_reduce: bool = True,
583
+ **kwargs
584
+ ):
585
+ device = self.dummy_param.device
586
+ loss_reduce = self.training or (loss_reduce and not self.training)
587
+
588
+ self.autoencoder.eval()
589
+ with torch.no_grad():
590
+ latent, latent_mask = self.autoencoder.encode(
591
+ waveform.unsqueeze(1), waveform_lengths
592
+ )
593
+
594
+ content, content_mask, global_duration_pred, _, _ = \
595
+ self.encode_content_with_instruction(
596
+ content, task, device, instruction, instruction_lengths
597
+ )
598
+
599
+ global_duration_loss = self.get_global_duration_loss(
600
+ global_duration_pred, latent_mask, reduce=loss_reduce
601
+ )
602
+
603
+ if self.training and self.classifier_free_guidance:
604
+ mask_indices = [
605
+ k for k in range(len(waveform))
606
+ if random.random() < self.cfg_drop_ratio
607
+ ]
608
+ if len(mask_indices) > 0:
609
+ content[mask_indices] = 0
610
+
611
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
612
+ latent, training=self.training
613
+ )
614
+
615
+ pred: torch.Tensor = self.backbone(
616
+ x=noisy_latent,
617
+ timesteps=timesteps,
618
+ context=content,
619
+ x_mask=latent_mask,
620
+ context_mask=content_mask,
621
+ )
622
+ pred = pred.transpose(1, self.autoencoder.time_dim)
623
+ target = target.transpose(1, self.autoencoder.time_dim)
624
+ diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none")
625
+ diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)
626
+
627
+ return {
628
+ "diff_loss": diff_loss,
629
+ "global_duration_loss": global_duration_loss,
630
+ }
631
+
632
+ @torch.no_grad()
633
+ def inference(
634
+ self,
635
+ content: list[Any],
636
+ condition: list[Any],
637
+ task: list[str],
638
+ is_time_aligned: Sequence[bool],
639
+ instruction: torch.Tensor,
640
+ instruction_lengths: torch.Tensor,
641
+ num_steps: int = 20,
642
+ sway_sampling_coef: float | None = -1.0,
643
+ guidance_scale: float = 3.0,
644
+ disable_progress=True,
645
+ use_gt_duration: bool = False,
646
+ **kwargs
647
+ ):
648
+ device = self.dummy_param.device
649
+ classifier_free_guidance = guidance_scale > 1.0
650
+
651
+ (
652
+ content,
653
+ content_mask,
654
+ global_duration_pred,
655
+ local_duration_pred,
656
+ _,
657
+ ) = self.encode_content_with_instruction(
658
+ content, task, device, instruction, instruction_lengths
659
+ )
660
+ batch_size = content.size(0)
661
+
662
+ if use_gt_duration:
663
+ raise NotImplementedError(
664
+ "Using ground truth global duration only is not implemented yet"
665
+ )
666
+
667
+ # prepare global duration
668
+ global_duration = self.prepare_global_duration(
669
+ global_duration_pred,
670
+ local_duration_pred,
671
+ is_time_aligned,
672
+ use_local=False
673
+ )
674
+ # TODO: manually set duration for SE and AudioSR
675
+ latent_length = torch.round(global_duration * self.latent_token_rate)
676
+ task_mask = torch.as_tensor([t in SAME_LENGTH_TASKS for t in task])
677
+ latent_length[task_mask] = content[task_mask].size(1)
678
+ latent_mask = create_mask_from_length(latent_length).to(device)
679
+ max_latent_length = latent_mask.sum(1).max().item()
680
+
681
+ # prepare latent and noise
682
+ if classifier_free_guidance:
683
+ uncond_context = torch.zeros_like(content)
684
+ uncond_content_mask = content_mask.detach().clone()
685
+ context = torch.cat([uncond_context, content])
686
+ context_mask = torch.cat([uncond_content_mask, content_mask])
687
+ else:
688
+ context = content
689
+ context_mask = content_mask
690
+
691
+ latent_shape = tuple(
692
+ max_latent_length if dim is None else dim
693
+ for dim in self.autoencoder.latent_shape
694
+ )
695
+ shape = (batch_size, *latent_shape)
696
+ latent = randn_tensor(
697
+ shape, generator=None, device=device, dtype=content.dtype
698
+ )
699
+ if not sway_sampling_coef:
700
+ sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
701
+ else:
702
+ t = torch.linspace(0, 1, num_steps + 1)
703
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
704
+ sigmas = 1 - t
705
+ timesteps, num_steps = self.retrieve_timesteps(
706
+ num_steps, device, timesteps=None, sigmas=sigmas
707
+ )
708
+ latent = self.iterative_denoise(
709
+ latent=latent,
710
+ timesteps=timesteps,
711
+ num_steps=num_steps,
712
+ verbose=not disable_progress,
713
+ cfg=classifier_free_guidance,
714
+ cfg_scale=guidance_scale,
715
+ backbone_input={
716
+ "x_mask": latent_mask,
717
+ "context": context,
718
+ "context_mask": context_mask,
719
+ }
720
+ )
721
+
722
+ waveform = self.autoencoder.decode(latent)
723
+ return waveform
724
+
725
+
726
+ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
727
+ def __init__(
728
+ self,
729
+ autoencoder: AutoEncoderBase,
730
+ content_encoder: ContentEncoder,
731
+ content_adapter: ContentAdapterBase,
732
+ backbone: nn.Module,
733
+ content_dim: int,
734
+ frame_resolution: float,
735
+ duration_offset: float = 1.0,
736
+ cfg_drop_ratio: float = 0.2,
737
+ sample_strategy: str = 'normal',
738
+ num_train_steps: int = 1000
739
+ ):
740
+
741
+ super().__init__(
742
+ autoencoder=autoencoder,
743
+ content_encoder=content_encoder,
744
+ content_adapter=content_adapter,
745
+ backbone=backbone,
746
+ content_dim=content_dim,
747
+ frame_resolution=frame_resolution,
748
+ duration_offset=duration_offset,
749
+ cfg_drop_ratio=cfg_drop_ratio,
750
+ sample_strategy=sample_strategy,
751
+ num_train_steps=num_train_steps
752
+ )
753
+ DurationAdapterMixin.__init__(
754
+ self,
755
+ latent_token_rate=autoencoder.latent_token_rate,
756
+ offset=duration_offset,
757
+ frame_resolution=frame_resolution
758
+ )
759
+ self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim))
760
+ self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim))
761
+
762
+ def get_backbone_input(
763
+ self, target_length: int, content: torch.Tensor,
764
+ content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
765
+ length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
766
+ ):
767
+ # TODO: compatibility with 2D spectrogram VAE
768
+ time_aligned_content = trim_or_pad_length(
769
+ time_aligned_content, target_length, 1
770
+ )
771
+ length_aligned_content = trim_or_pad_length(
772
+ length_aligned_content, target_length, 1
773
+ )
774
+ # time_aligned_content: from monotonic aligned input, without frame expansion (phoneme)
775
+ # length_aligned_content: from aligned input (f0/energy)
776
+ time_aligned_content = time_aligned_content + length_aligned_content
777
+ time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
778
+ time_aligned_content.dtype
779
+ )
780
+
781
+ context = content
782
+ context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
783
+ # only use the first dummy non time aligned embedding
784
+ context_mask = content_mask.detach().clone()
785
+ context_mask[is_time_aligned, 1:] = False
786
+
787
+ # truncate dummy non time aligned context
788
+ if is_time_aligned.sum().item() < content.size(0):
789
+ trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
790
+ else:
791
+ trunc_nta_length = content.size(1)
792
+ context = context[:, :trunc_nta_length]
793
+ context_mask = context_mask[:, :trunc_nta_length]
794
+
795
+ return context, context_mask, time_aligned_content
796
+
797
+ def forward(
798
+ self,
799
+ content: list[Any],
800
+ duration: Sequence[float],
801
+ task: list[str],
802
+ is_time_aligned: Sequence[bool],
803
+ waveform: torch.Tensor,
804
+ waveform_lengths: torch.Tensor,
805
+ instruction: torch.Tensor,
806
+ instruction_lengths: torch.Tensor,
807
+ loss_reduce: bool = True,
808
+ **kwargs
809
+ ):
810
+ device = self.dummy_param.device
811
+ loss_reduce = self.training or (loss_reduce and not self.training)
812
+
813
+ self.autoencoder.eval()
814
+ with torch.no_grad():
815
+ latent, latent_mask = self.autoencoder.encode(
816
+ waveform.unsqueeze(1), waveform_lengths
817
+ )
818
+
819
+ (
820
+ content, content_mask, global_duration_pred, local_duration_pred,
821
+ length_aligned_content
822
+ ) = self.encode_content_with_instruction(
823
+ content, task, device, instruction, instruction_lengths
824
+ )
825
+
826
+ # truncate the unused non-time-aligned duration prediction
827
+ if is_time_aligned.sum() > 0:
828
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
829
+ else:
830
+ trunc_ta_length = content.size(1)
831
+
832
+ # duration loss
833
+ local_duration_pred = local_duration_pred[:, :trunc_ta_length]
834
+ ta_content_mask = content_mask[:, :trunc_ta_length]
835
+ local_duration_loss = self.get_local_duration_loss(
836
+ duration,
837
+ local_duration_pred,
838
+ ta_content_mask,
839
+ is_time_aligned,
840
+ reduce=loss_reduce
841
+ )
842
+
843
+ global_duration_loss = self.get_global_duration_loss(
844
+ global_duration_pred, latent_mask, reduce=loss_reduce
845
+ )
846
+
847
+ # --------------------------------------------------------------------
848
+ # prepare latent and noise
849
+ # --------------------------------------------------------------------
850
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
851
+ latent, training=self.training
852
+ )
853
+
854
+ # --------------------------------------------------------------------
855
+ # duration adapter
856
+ # --------------------------------------------------------------------
857
+ if is_time_aligned.sum() == 0 and \
858
+ duration.size(1) < content_mask.size(1):
859
+ duration = F.pad(
860
+ duration, (0, content_mask.size(1) - duration.size(1))
861
+ )
862
+ time_aligned_content, _ = self.expand_by_duration(
863
+ x=content[:, :trunc_ta_length],
864
+ content_mask=ta_content_mask,
865
+ local_duration=duration,
866
+ )
867
+
868
+ # --------------------------------------------------------------------
869
+ # prepare input to the backbone
870
+ # --------------------------------------------------------------------
871
+ # TODO compatibility for 2D spectrogram VAE
872
+ latent_length = noisy_latent.size(self.autoencoder.time_dim)
873
+ context, context_mask, time_aligned_content = self.get_backbone_input(
874
+ latent_length, content, content_mask, time_aligned_content,
875
+ length_aligned_content, is_time_aligned
876
+ )
877
+
878
+ # --------------------------------------------------------------------
879
+ # classifier free guidance
880
+ # --------------------------------------------------------------------
881
+ if self.training and self.classifier_free_guidance:
882
+ mask_indices = [
883
+ k for k in range(len(waveform))
884
+ if random.random() < self.cfg_drop_ratio
885
+ ]
886
+ if len(mask_indices) > 0:
887
+ context[mask_indices] = 0
888
+ time_aligned_content[mask_indices] = 0
889
+
890
+ pred: torch.Tensor = self.backbone(
891
+ x=noisy_latent,
892
+ x_mask=latent_mask,
893
+ timesteps=timesteps,
894
+ context=context,
895
+ context_mask=context_mask,
896
+ time_aligned_context=time_aligned_content,
897
+ )
898
+ pred = pred.transpose(1, self.autoencoder.time_dim)
899
+ target = target.transpose(1, self.autoencoder.time_dim)
900
+ diff_loss = F.mse_loss(pred, target, reduction="none")
901
+ diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)
902
+ return {
903
+ "diff_loss": diff_loss,
904
+ "local_duration_loss": local_duration_loss,
905
+ "global_duration_loss": global_duration_loss,
906
+ }
907
+
908
+ def inference(
909
+ self,
910
+ content: list[Any],
911
+ task: list[str],
912
+ is_time_aligned: Sequence[bool],
913
+ instruction: torch.Tensor,
914
+ instruction_lengths: Sequence[int],
915
+ num_steps: int = 20,
916
+ sway_sampling_coef: float | None = -1.0,
917
+ guidance_scale: float = 3.0,
918
+ disable_progress: bool = True,
919
+ use_gt_duration: bool = False,
920
+ **kwargs
921
+ ):
922
+ device = self.dummy_param.device
923
+ classifier_free_guidance = guidance_scale > 1.0
924
+
925
+ (
926
+ content, content_mask, global_duration_pred, local_duration_pred,
927
+ length_aligned_content
928
+ ) = self.encode_content_with_instruction(
929
+ content, task, device, instruction, instruction_lengths
930
+ )
931
+ # print("content std: ", content.std())
932
+ batch_size = content.size(0)
933
+
934
+ # truncate the dummy time-aligned duration prediction
935
+ is_time_aligned = torch.as_tensor(is_time_aligned)
936
+ if is_time_aligned.sum() > 0:
937
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
938
+ else:
939
+ trunc_ta_length = content.size(1)
940
+
941
+ # prepare local duration
942
+ local_duration = self.prepare_local_duration(
943
+ local_duration_pred, content_mask
944
+ )
945
+ local_duration = local_duration[:, :trunc_ta_length]
946
+ # use ground truth duration
947
+ if use_gt_duration and "duration" in kwargs:
948
+ local_duration = torch.as_tensor(kwargs["duration"]).to(device)
949
+
950
+ # prepare global duration
951
+ global_duration = self.prepare_global_duration(
952
+ global_duration_pred, local_duration, is_time_aligned
953
+ )
954
+
955
+ # --------------------------------------------------------------------
956
+ # duration adapter
957
+ # --------------------------------------------------------------------
958
+ time_aligned_content, latent_mask = self.expand_by_duration(
959
+ x=content[:, :trunc_ta_length],
960
+ content_mask=content_mask[:, :trunc_ta_length],
961
+ local_duration=local_duration,
962
+ global_duration=global_duration,
963
+ )
964
+
965
+ context, context_mask, time_aligned_content = self.get_backbone_input(
966
+ target_length=time_aligned_content.size(1),
967
+ content=content,
968
+ content_mask=content_mask,
969
+ time_aligned_content=time_aligned_content,
970
+ length_aligned_content=length_aligned_content,
971
+ is_time_aligned=is_time_aligned
972
+ )
973
+
974
+ # --------------------------------------------------------------------
975
+ # prepare unconditional input
976
+ # --------------------------------------------------------------------
977
+ if classifier_free_guidance:
978
+ uncond_time_aligned_content = torch.zeros_like(
979
+ time_aligned_content
980
+ )
981
+ uncond_context = torch.zeros_like(context)
982
+ uncond_context_mask = context_mask.detach().clone()
983
+ time_aligned_content = torch.cat([
984
+ uncond_time_aligned_content, time_aligned_content
985
+ ])
986
+ context = torch.cat([uncond_context, context])
987
+ context_mask = torch.cat([uncond_context_mask, context_mask])
988
+ latent_mask = torch.cat([
989
+ latent_mask, latent_mask.detach().clone()
990
+ ])
991
+
992
+ # --------------------------------------------------------------------
993
+ # prepare input to the backbone
994
+ # --------------------------------------------------------------------
995
+ latent_length = latent_mask.sum(1).max().item()
996
+ latent_shape = tuple(
997
+ latent_length if dim is None else dim
998
+ for dim in self.autoencoder.latent_shape
999
+ )
1000
+ shape = (batch_size, *latent_shape)
1001
+ latent = randn_tensor(
1002
+ shape, generator=None, device=device, dtype=content.dtype
1003
+ )
1004
+
1005
+ if not sway_sampling_coef:
1006
+ sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
1007
+ else:
1008
+ t = torch.linspace(0, 1, num_steps + 1)
1009
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
1010
+ sigmas = 1 - t
1011
+ timesteps, num_steps = self.retrieve_timesteps(
1012
+ num_steps, device, timesteps=None, sigmas=sigmas
1013
+ )
1014
+ latent = self.iterative_denoise(
1015
+ latent=latent,
1016
+ timesteps=timesteps,
1017
+ num_steps=num_steps,
1018
+ verbose=not disable_progress,
1019
+ cfg=classifier_free_guidance,
1020
+ cfg_scale=guidance_scale,
1021
+ backbone_input={
1022
+ "x_mask": latent_mask,
1023
+ "context": context,
1024
+ "context_mask": context_mask,
1025
+ "time_aligned_context": time_aligned_content,
1026
+ }
1027
+ )
1028
+
1029
+ waveform = self.autoencoder.decode(latent)
1030
+ return waveform
1031
+
1032
+
1033
+ class DoubleContentAudioFlowMatching(DummyContentAudioFlowMatching):
1034
+ def get_backbone_input(
1035
+ self, target_length: int, content: torch.Tensor,
1036
+ content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
1037
+ length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
1038
+ ):
1039
+ # TODO compatibility for 2D spectrogram VAE
1040
+ time_aligned_content = trim_or_pad_length(
1041
+ time_aligned_content, target_length, 1
1042
+ )
1043
+ context_length = min(content.size(1), time_aligned_content.size(1))
1044
+ time_aligned_content[~is_time_aligned, :context_length] = content[
1045
+ ~is_time_aligned, :context_length]
1046
+ length_aligned_content = trim_or_pad_length(
1047
+ length_aligned_content, target_length, 1
1048
+ )
1049
+ # time_aligned_content: from monotonically aligned input, without frame expansion (phoneme)
1050
+ # length_aligned_content: from aligned input (f0/energy)
1051
+ time_aligned_content = time_aligned_content + length_aligned_content
1052
+
1053
+ context = content
1054
+ context_mask = content_mask.detach().clone()
1055
+
1056
+ return context, context_mask, time_aligned_content
1057
+
1058
+
1059
+ class HybridContentAudioFlowMatching(DummyContentAudioFlowMatching):
1060
+ def get_backbone_input(
1061
+ self, target_length: int, content: torch.Tensor,
1062
+ content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
1063
+ length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
1064
+ ):
1065
+ # TODO compatibility for 2D spectrogram VAE
1066
+ time_aligned_content = trim_or_pad_length(
1067
+ time_aligned_content, target_length, 1
1068
+ )
1069
+ length_aligned_content = trim_or_pad_length(
1070
+ length_aligned_content, target_length, 1
1071
+ )
1072
+ # time_aligned_content: from monotonically aligned input, without frame expansion (phoneme)
1073
+ # length_aligned_content: from aligned input (f0/energy)
1074
+ time_aligned_content = time_aligned_content + length_aligned_content
1075
+ time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
1076
+ time_aligned_content.dtype
1077
+ )
1078
+
1079
+ context = content
1080
+ context_mask = content_mask.detach().clone()
1081
+
1082
+ return context, context_mask, time_aligned_content
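
The `inference` methods above build their noise schedule in one of two ways: a uniform ladder of sigmas, or a "sway"-warped ladder that concentrates steps near the pure-noise end. The following standalone sketch re-derives only that schedule (it is not part of the model classes, and `retrieve_timesteps` / `iterative_denoise` are deliberately not reproduced):

import torch

def build_sigmas(num_steps: int, sway_sampling_coef: float | None = -1.0) -> torch.Tensor:
    # Without a sway coefficient: uniformly spaced sigmas from 1.0 down to 1/num_steps.
    if not sway_sampling_coef:
        return torch.linspace(1.0, 1.0 / num_steps, num_steps)
    # With sway sampling: warp a uniform grid t in [0, 1] with a cosine term and
    # take sigmas = 1 - t. For coef = -1 this reduces to cos(pi/2 * t), so more
    # denoising steps are spent near sigma = 1 (the high-noise region).
    t = torch.linspace(0, 1, num_steps + 1)
    t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
    return 1 - t

print(build_sigmas(8))        # warped schedule, 9 values ending at 0
print(build_sigmas(8, None))  # uniform schedule, 8 values

Note that the two branches intentionally differ in length, mirroring the code above: the uniform branch yields `num_steps` sigmas, while the sway branch also includes the terminal sigma of 0.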
requirements.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchdata
3
+ diffusers
4
+ hydra-core
5
+ omegaconf
6
+ tqdm
7
+ accelerate
8
+ wandb
9
+ einops
10
+ transformers
11
+ alias_free_torch
12
+ h5py
13
+ torchaudio
14
+ soundfile
15
+ tensorboard
16
+ swanlab
17
+ fire
18
+ sentencepiece
19
+ librosa
20
+ pypinyin
21
+ g2p_en
22
+ git+https://github.com/wenet-e2e/wespeaker.git
utils/accelerate_utilities.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from accelerate import Accelerator
2
+
3
+
4
+ class AcceleratorSaveTrainableParams(Accelerator):
5
+ def get_state_dict(self, model, unwrap=True):
6
+ state_dict = super().get_state_dict(model, unwrap)
7
+ if hasattr(model, "param_names_to_save"):
8
+ param_names_to_save = model.param_names_to_save
9
+ return {
10
+ k: v
11
+ for k, v in state_dict.items() if k in param_names_to_save
12
+ }
13
+ return state_dict
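
`AcceleratorSaveTrainableParams` only changes which keys end up in a saved checkpoint: if the wrapped model exposes a `param_names_to_save` attribute, every other entry is filtered out of the state dict. A minimal usage sketch, assuming a single-process run from the repository root so that the `utils.*` import resolves; `TinyModel` is hypothetical:

import torch.nn as nn
from utils.accelerate_utilities import AcceleratorSaveTrainableParams

class TinyModel(nn.Module):
    # Hypothetical module: only the adapter is trainable.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)   # frozen
        self.adapter = nn.Linear(8, 8)    # trainable
        self.backbone.requires_grad_(False)
        # record which parameters should be checkpointed
        self.param_names_to_save = [
            name for name, p in self.named_parameters() if p.requires_grad
        ]

accelerator = AcceleratorSaveTrainableParams()
model = accelerator.prepare(TinyModel())
state_dict = accelerator.get_state_dict(model)
print(sorted(state_dict))  # only adapter.weight and adapter.bias remain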
utils/audio.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchaudio
4
+
5
+
6
+ class PadCrop(nn.Module):
7
+ def __init__(self, n_samples, randomize=True):
8
+ super().__init__()
9
+ self.n_samples = n_samples
10
+ self.randomize = randomize
11
+
12
+ def __call__(self, signal):
13
+ n, s = signal.shape
14
+ start = 0 if (
15
+ not self.randomize
16
+ ) else torch.randint(0,
17
+ max(0, s - self.n_samples) + 1, []).item()
18
+ end = start + self.n_samples
19
+ output = signal.new_zeros([n, self.n_samples])
20
+ output[:, :min(s, self.n_samples)] = signal[:, start:end]
21
+ return output
22
+
23
+
24
+ def set_audio_channels(audio, target_channels):
25
+ if target_channels == 1:
26
+ # Convert to mono
27
+ audio = audio.mean(1, keepdim=True)
28
+ elif target_channels == 2:
29
+ # Convert to stereo
30
+ if audio.shape[1] == 1:
31
+ audio = audio.repeat(1, 2, 1)
32
+ elif audio.shape[1] > 2:
33
+ audio = audio[:, :2, :]
34
+ return audio
35
+
36
+
37
+ def prepare_audio(
38
+ audio, in_sr, target_sr, target_length, target_channels, device
39
+ ):
40
+
41
+ audio = audio.to(device)
42
+
43
+ if in_sr != target_sr:
44
+ resample_tf = torchaudio.transforms.Resample(in_sr,
45
+ target_sr).to(device)
46
+ audio = resample_tf(audio)
47
+
48
+ audio = PadCrop(target_length, randomize=False)(audio)
49
+
50
+ # Add batch dimension
51
+ if audio.dim() == 1:
52
+ audio = audio.unsqueeze(0).unsqueeze(0)
53
+ elif audio.dim() == 2:
54
+ audio = audio.unsqueeze(0)
55
+
56
+ audio = set_audio_channels(audio, target_channels)
57
+
58
+ return audio
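
`prepare_audio` chains the helpers above: move to device, resample, pad or crop to a fixed number of samples, add a batch dimension, and match the channel count. A small sketch on a synthetic mono clip (all values are hypothetical):

import torch
from utils.audio import prepare_audio

waveform = torch.randn(1, 24000)  # 1 s of mono audio at 24 kHz
out = prepare_audio(
    waveform,
    in_sr=24000,
    target_sr=16000,
    target_length=32000,    # 2 s at 16 kHz; the missing tail is zero-padded
    target_channels=2,      # mono is duplicated to stereo
    device="cpu",
)
print(out.shape)  # torch.Size([1, 2, 32000])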
utils/config.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import sys
3
+ import os
4
+
5
+ import hydra
6
+ import omegaconf
7
+ from omegaconf import OmegaConf
8
+
9
+
10
+ def multiply(*args):
11
+ result = 1
12
+ for arg in args:
13
+ result *= arg
14
+ return result
15
+
16
+
17
+ def get_pitch_downsample_ratio(
18
+ autoencoder_config: dict, pitch_frame_resolution: float
19
+ ):
20
+ latent_frame_resolution = autoencoder_config[
21
+ "downsampling_ratio"] / autoencoder_config["sample_rate"]
22
+ return round(latent_frame_resolution / pitch_frame_resolution)
23
+
24
+
25
+ def register_omegaconf_resolvers() -> None:
26
+ """
27
+ Register custom resolver for hydra configs, which can be used in YAML
28
+ files for dynamically setting values
29
+ """
30
+ OmegaConf.clear_resolvers()
31
+ OmegaConf.register_new_resolver("len", len, replace=True)
32
+ OmegaConf.register_new_resolver("multiply", multiply, replace=True)
33
+ OmegaConf.register_new_resolver(
34
+ "get_pitch_downsample_ratio", get_pitch_downsample_ratio, replace=True
35
+ )
36
+
37
+
38
+ def generate_config_from_command_line_overrides(
39
+ config_file: str | Path
40
+ ) -> omegaconf.DictConfig:
41
+ register_omegaconf_resolvers()
42
+
43
+ config_file = Path(config_file).resolve()
44
+ config_name = config_file.name.__str__()
45
+ config_path = config_file.parent.__str__()
46
+ config_path = os.path.relpath(config_path, Path(__file__).resolve().parent)
47
+
48
+ overrides = sys.argv[1:]
49
+ with hydra.initialize(version_base=None, config_path=config_path):
50
+ config = hydra.compose(config_name=config_name, overrides=overrides)
51
+ omegaconf.OmegaConf.resolve(config)
52
+
53
+ return config
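
The resolvers registered in `register_omegaconf_resolvers` (`len`, `multiply`, `get_pitch_downsample_ratio`) can then be referenced from the YAML configs that Hydra composes. A small sketch with an inline config standing in for the real config files (the keys below are hypothetical):

from omegaconf import OmegaConf
from utils.config import register_omegaconf_resolvers

register_omegaconf_resolvers()
cfg = OmegaConf.create({
    "duration": 10,
    "latent_token_rate": 75,
    # resolved through the custom "multiply" resolver
    "latent_length": "${multiply:${duration},${latent_token_rate}}",
})
OmegaConf.resolve(cfg)
print(cfg.latent_length)  # 750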
utils/diffsinger_utilities.py ADDED
@@ -0,0 +1,721 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import six
2
+ from pathlib import Path
3
+ import re
4
+ import json
5
+ from collections import OrderedDict
6
+ from typing import Union
7
+
8
+ from pypinyin import pinyin, lazy_pinyin, Style
9
+ import numpy as np
10
+ import librosa
11
+ import torch
12
+
13
+ PAD = "<pad>"
14
+ EOS = "<EOS>"
15
+ UNK = "<UNK>"
16
+ SEG = "|"
17
+ RESERVED_TOKENS = [PAD, EOS, UNK]
18
+ NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
19
+ PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0
20
+ EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1
21
+ UNK_ID = RESERVED_TOKENS.index(UNK) # Normally 2
22
+
23
+ F0_BIN = 256
24
+ F0_MAX = 1100.0
25
+ F0_MIN = 50.0
26
+ F0_MEL_MIN = 1127 * np.log(1 + F0_MIN / 700)
27
+ F0_MEL_MAX = 1127 * np.log(1 + F0_MAX / 700)
28
+
29
+
30
+ def f0_to_coarse(f0):
31
+ is_torch = isinstance(f0, torch.Tensor)
32
+ f0_mel = 1127 * (1 + f0 /
33
+ 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
34
+ f0_mel[f0_mel > 0
35
+ ] = (f0_mel[f0_mel > 0] -
36
+ F0_MEL_MIN) * (F0_BIN - 2) / (F0_MEL_MAX - F0_MEL_MIN) + 1
37
+
38
+ f0_mel[f0_mel <= 1] = 1
39
+ f0_mel[f0_mel > F0_BIN - 1] = F0_BIN - 1
40
+ f0_coarse = (f0_mel +
41
+ 0.5).long() if is_torch else np.rint(f0_mel).astype(int)
42
+ assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (
43
+ f0_coarse.max(), f0_coarse.min()
44
+ )
45
+ return f0_coarse
46
+
47
+
48
+ def norm_f0(
49
+ f0: Union[np.ndarray, torch.Tensor],
50
+ uv: Union[None, np.ndarray],
51
+ f0_mean: float,
52
+ f0_std: float,
53
+ pitch_norm: str = "log",
54
+ use_uv: bool = True
55
+ ):
56
+ is_torch = isinstance(f0, torch.Tensor)
57
+ if pitch_norm == 'standard':
58
+ f0 = (f0 - f0_mean) / f0_std
59
+ if pitch_norm == 'log':
60
+ f0 = torch.log2(f0) if is_torch else np.log2(f0)
61
+ if uv is not None and use_uv:
62
+ f0[uv > 0] = 0
63
+ return f0
64
+
65
+
66
+ def norm_interp_f0(
67
+ f0: Union[np.ndarray, torch.Tensor],
68
+ f0_mean: float,
69
+ f0_std: float,
70
+ pitch_norm: str = "log",
71
+ use_uv: bool = True
72
+ ):
73
+ is_torch = isinstance(f0, torch.Tensor)
74
+ if is_torch:
75
+ device = f0.device
76
+ f0 = f0.data.cpu().numpy()
77
+ uv = f0 == 0
78
+ f0 = norm_f0(f0, uv, f0_mean, f0_std, pitch_norm, use_uv)
79
+ if sum(uv) == len(f0):
80
+ f0[uv] = 0
81
+ elif sum(uv) > 0:
82
+ f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
83
+ uv = torch.as_tensor(uv).float()
84
+ f0 = torch.as_tensor(f0).float()
85
+ if is_torch:
86
+ f0 = f0.to(device)
87
+ return f0, uv
88
+
89
+
90
+ def denorm_f0(
91
+ f0,
92
+ uv,
93
+ pitch_norm="log",
94
+ f0_mean=None,
95
+ f0_std=None,
96
+ pitch_padding=None,
97
+ min=None,
98
+ max=None,
99
+ use_uv=True
100
+ ):
101
+ if pitch_norm == 'standard':
102
+ f0 = f0 * f0_std + f0_mean
103
+ if pitch_norm == 'log':
104
+ f0 = 2**f0
105
+ if min is not None:
106
+ f0 = f0.clamp(min=min)
107
+ if max is not None:
108
+ f0 = f0.clamp(max=max)
109
+ if uv is not None and use_uv:
110
+ f0[uv > 0] = 0
111
+ if pitch_padding is not None:
112
+ f0[pitch_padding] = 0
113
+ return f0
114
+
115
+
116
+ def librosa_pad_lr(x, fshift, pad_sides=1):
117
+ '''Compute right padding (final frame) or padding on both sides (first and final frames).
118
+ '''
119
+ assert pad_sides in (1, 2)
120
+ # return int(fsize // 2)
121
+ pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0]
122
+ if pad_sides == 1:
123
+ return 0, pad
124
+ else:
125
+ return pad // 2, pad // 2 + pad % 2
126
+
127
+
128
+ def get_pitch(
129
+ wav_file: Union[str, Path], sample_rate: int, frame_shift: float
130
+ ):
131
+ import parselmouth
132
+ hop_size = int(frame_shift * sample_rate)
133
+ wav, _ = librosa.core.load(wav_file, sr=sample_rate)
134
+ # l_pad, r_pad = librosa_pad_lr(wav, hop_size, 1)
135
+ # wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
136
+
137
+ latent_length = wav.shape[0] // hop_size
138
+ f0_min = 80
139
+ f0_max = 750
140
+ pad_size = 4
141
+
142
+ f0 = parselmouth.Sound(wav, sample_rate).to_pitch_ac(
143
+ time_step=frame_shift,
144
+ voicing_threshold=0.6,
145
+ pitch_floor=f0_min,
146
+ pitch_ceiling=f0_max
147
+ ).selected_array['frequency']
148
+ delta_l = latent_length - len(f0)
149
+ if delta_l > 0:
150
+ f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
151
+ pitch_coarse = f0_to_coarse(f0)
152
+ return f0, pitch_coarse
153
+
154
+
155
+ def remove_empty_lines(text):
156
+ """remove empty lines"""
157
+ assert (len(text) > 0)
158
+ assert (isinstance(text, list))
159
+ text = [t.strip() for t in text]
160
+ if "" in text:
161
+ text.remove("")
162
+ return text
163
+
164
+
165
+ def is_sil_phoneme(p):
166
+ return not p[0].isalpha()
167
+
168
+
169
+ def strip_ids(ids, ids_to_strip):
170
+ """Strip ids_to_strip from the end ids."""
171
+ ids = list(ids)
172
+ while ids and ids[-1] in ids_to_strip:
173
+ ids.pop()
174
+ return ids
175
+
176
+
177
+ class TextEncoder(object):
178
+ """Base class for converting from ints to/from human readable strings."""
179
+ def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
180
+ self._num_reserved_ids = num_reserved_ids
181
+
182
+ @property
183
+ def num_reserved_ids(self):
184
+ return self._num_reserved_ids
185
+
186
+ def encode(self, s):
187
+ """Transform a human-readable string into a sequence of int ids.
188
+
189
+ The ids should be in the range [num_reserved_ids, vocab_size). Ids [0,
190
+ num_reserved_ids) are reserved.
191
+
192
+ EOS is not appended.
193
+
194
+ Args:
195
+ s: human-readable string to be converted.
196
+
197
+ Returns:
198
+ ids: list of integers
199
+ """
200
+ return [int(w) + self._num_reserved_ids for w in s.split()]
201
+
202
+ def decode(self, ids, strip_extraneous=False):
203
+ """Transform a sequence of int ids into a human-readable string.
204
+
205
+ EOS is not expected in ids.
206
+
207
+ Args:
208
+ ids: list of integers to be converted.
209
+ strip_extraneous: bool, whether to strip off extraneous tokens
210
+ (EOS and PAD).
211
+
212
+ Returns:
213
+ s: human-readable string.
214
+ """
215
+ if strip_extraneous:
216
+ ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
217
+ return " ".join(self.decode_list(ids))
218
+
219
+ def decode_list(self, ids):
220
+ """Transform a sequence of int ids into a their string versions.
221
+
222
+ This method supports transforming individual input/output ids to their
223
+ string versions so that sequence to/from text conversions can be visualized
224
+ in a human readable format.
225
+
226
+ Args:
227
+ ids: list of integers to be converted.
228
+
229
+ Returns:
230
+ strs: list of human-readable string.
231
+ """
232
+ decoded_ids = []
233
+ for id_ in ids:
234
+ if 0 <= id_ < self._num_reserved_ids:
235
+ decoded_ids.append(RESERVED_TOKENS[int(id_)])
236
+ else:
237
+ decoded_ids.append(id_ - self._num_reserved_ids)
238
+ return [str(d) for d in decoded_ids]
239
+
240
+ @property
241
+ def vocab_size(self):
242
+ raise NotImplementedError()
243
+
244
+
245
+ class TokenTextEncoder(TextEncoder):
246
+ """Encoder based on a user-supplied vocabulary (file or list)."""
247
+ def __init__(
248
+ self,
249
+ vocab_filename,
250
+ reverse=False,
251
+ vocab_list=None,
252
+ replace_oov=None,
253
+ num_reserved_ids=NUM_RESERVED_TOKENS
254
+ ):
255
+ """Initialize from a file or list, one token per line.
256
+
257
+ Handling of reserved tokens works as follows:
258
+ - When initializing from a list, we add reserved tokens to the vocab.
259
+ - When initializing from a file, we do not add reserved tokens to the vocab.
260
+ - When saving vocab files, we save reserved tokens to the file.
261
+
262
+ Args:
263
+ vocab_filename: If not None, the full filename to read vocab from. If this
264
+ is not None, then vocab_list should be None.
265
+ reverse: Boolean indicating if tokens should be reversed during encoding
266
+ and decoding.
267
+ vocab_list: If not None, a list of elements of the vocabulary. If this is
268
+ not None, then vocab_filename should be None.
269
+ replace_oov: If not None, every out-of-vocabulary token seen when
270
+ encoding will be replaced by this string (which must be in vocab).
271
+ num_reserved_ids: Number of IDs to save for reserved tokens like <EOS>.
272
+ """
273
+ super(TokenTextEncoder,
274
+ self).__init__(num_reserved_ids=num_reserved_ids)
275
+ self._reverse = reverse
276
+ self._replace_oov = replace_oov
277
+ if vocab_filename:
278
+ self._init_vocab_from_file(vocab_filename)
279
+ else:
280
+ assert vocab_list is not None
281
+ self._init_vocab_from_list(vocab_list)
282
+ self.pad_index = self._token_to_id[PAD]
283
+ self.eos_index = self._token_to_id[EOS]
284
+ self.unk_index = self._token_to_id[UNK]
285
+ self.seg_index = self._token_to_id[
286
+ SEG] if SEG in self._token_to_id else self.eos_index
287
+
288
+ def encode(self, s):
289
+ """Converts a space-separated string of tokens to a list of ids."""
290
+ sentence = s
291
+ tokens = sentence.strip().split()
292
+ if self._replace_oov is not None:
293
+ tokens = [
294
+ t if t in self._token_to_id else self._replace_oov
295
+ for t in tokens
296
+ ]
297
+ ret = [self._token_to_id[tok] for tok in tokens]
298
+ return ret[::-1] if self._reverse else ret
299
+
300
+ def decode(self, ids, strip_eos=False, strip_padding=False):
301
+ if strip_padding and self.pad() in list(ids):
302
+ pad_pos = list(ids).index(self.pad())
303
+ ids = ids[:pad_pos]
304
+ if strip_eos and self.eos() in list(ids):
305
+ eos_pos = list(ids).index(self.eos())
306
+ ids = ids[:eos_pos]
307
+ return " ".join(self.decode_list(ids))
308
+
309
+ def decode_list(self, ids):
310
+ seq = reversed(ids) if self._reverse else ids
311
+ return [self._safe_id_to_token(i) for i in seq]
312
+
313
+ @property
314
+ def vocab_size(self):
315
+ return len(self._id_to_token)
316
+
317
+ def __len__(self):
318
+ return self.vocab_size
319
+
320
+ def _safe_id_to_token(self, idx):
321
+ return self._id_to_token.get(idx, "ID_%d" % idx)
322
+
323
+ def _init_vocab_from_file(self, filename):
324
+ """Load vocab from a file.
325
+
326
+ Args:
327
+ filename: The file to load vocabulary from.
328
+ """
329
+ with open(filename) as f:
330
+ tokens = [token.strip() for token in f.readlines()]
331
+
332
+ def token_gen():
333
+ for token in tokens:
334
+ yield token
335
+
336
+ self._init_vocab(token_gen(), add_reserved_tokens=False)
337
+
338
+ def _init_vocab_from_list(self, vocab_list):
339
+ """Initialize tokens from a list of tokens.
340
+
341
+ It is ok if reserved tokens appear in the vocab list. They will be
342
+ removed. The set of tokens in vocab_list should be unique.
343
+
344
+ Args:
345
+ vocab_list: A list of tokens.
346
+ """
347
+ def token_gen():
348
+ for token in vocab_list:
349
+ if token not in RESERVED_TOKENS:
350
+ yield token
351
+
352
+ self._init_vocab(token_gen())
353
+
354
+ def _init_vocab(self, token_generator, add_reserved_tokens=True):
355
+ """Initialize vocabulary with tokens from token_generator."""
356
+
357
+ self._id_to_token = {}
358
+ non_reserved_start_index = 0
359
+
360
+ if add_reserved_tokens:
361
+ self._id_to_token.update(enumerate(RESERVED_TOKENS))
362
+ non_reserved_start_index = len(RESERVED_TOKENS)
363
+
364
+ self._id_to_token.update(
365
+ enumerate(token_generator, start=non_reserved_start_index)
366
+ )
367
+
368
+ # _token_to_id is the reverse of _id_to_token
369
+ self._token_to_id = dict((v, k)
370
+ for k, v in six.iteritems(self._id_to_token))
371
+
372
+ def pad(self):
373
+ return self.pad_index
374
+
375
+ def eos(self):
376
+ return self.eos_index
377
+
378
+ def unk(self):
379
+ return self.unk_index
380
+
381
+ def seg(self):
382
+ return self.seg_index
383
+
384
+ def store_to_file(self, filename):
385
+ """Write vocab file to disk.
386
+
387
+ Vocab files have one token per line. The file ends in a newline. Reserved
388
+ tokens are written to the vocab file as well.
389
+
390
+ Args:
391
+ filename: Full path of the file to store the vocab to.
392
+ """
393
+ with open(filename, "w") as f:
394
+ for i in range(len(self._id_to_token)):
395
+ f.write(self._id_to_token[i] + "\n")
396
+
397
+ def sil_phonemes(self):
398
+ return [p for p in self._id_to_token.values() if not p[0].isalpha()]
399
+
400
+
401
+ class TextGrid(object):
402
+ def __init__(self, text):
403
+ text = remove_empty_lines(text)
404
+ self.text = text
405
+ self.line_count = 0
406
+ self._get_type()
407
+ self._get_time_intval()
408
+ self._get_size()
409
+ self.tier_list = []
410
+ self._get_item_list()
411
+
412
+ def _extract_pattern(self, pattern, inc):
413
+ """
414
+ Parameters
415
+ ----------
416
+ pattern : regex to extract pattern
417
+ inc : increment of line count after extraction
418
+ Returns
419
+ -------
420
+ group : extracted info
421
+ """
422
+ try:
423
+ group = re.match(pattern, self.text[self.line_count]).group(1)
424
+ self.line_count += inc
425
+ except AttributeError:
426
+ raise ValueError(
427
+ "File format error at line %d:%s" %
428
+ (self.line_count, self.text[self.line_count])
429
+ )
430
+ return group
431
+
432
+ def _get_type(self):
433
+ self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2)
434
+
435
+ def _get_time_intval(self):
436
+ self.xmin = self._extract_pattern(r"xmin = (.*)", 1)
437
+ self.xmax = self._extract_pattern(r"xmax = (.*)", 2)
438
+
439
+ def _get_size(self):
440
+ self.size = int(self._extract_pattern(r"size = (.*)", 2))
441
+
442
+ def _get_item_list(self):
443
+ """Only supports IntervalTier currently"""
444
+ for itemIdx in range(1, self.size + 1):
445
+ tier = OrderedDict()
446
+ item_list = []
447
+ tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1)
448
+ tier_class = self._extract_pattern(r"class = \"(.*)\"", 1)
449
+ if tier_class != "IntervalTier":
450
+ raise NotImplementedError(
451
+ "Only IntervalTier class is supported currently"
452
+ )
453
+ tier_name = self._extract_pattern(r"name = \"(.*)\"", 1)
454
+ tier_xmin = self._extract_pattern(r"xmin = (.*)", 1)
455
+ tier_xmax = self._extract_pattern(r"xmax = (.*)", 1)
456
+ tier_size = self._extract_pattern(r"intervals: size = (.*)", 1)
457
+ for i in range(int(tier_size)):
458
+ item = OrderedDict()
459
+ item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1)
460
+ item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1)
461
+ item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1)
462
+ item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1)
463
+ item_list.append(item)
464
+ tier["idx"] = tier_idx
465
+ tier["class"] = tier_class
466
+ tier["name"] = tier_name
467
+ tier["xmin"] = tier_xmin
468
+ tier["xmax"] = tier_xmax
469
+ tier["size"] = tier_size
470
+ tier["items"] = item_list
471
+ self.tier_list.append(tier)
472
+
473
+ def toJson(self):
474
+ _json = OrderedDict()
475
+ _json["file_type"] = self.file_type
476
+ _json["xmin"] = self.xmin
477
+ _json["xmax"] = self.xmax
478
+ _json["size"] = self.size
479
+ _json["tiers"] = self.tier_list
480
+ return json.dumps(_json, ensure_ascii=False, indent=2)
481
+
482
+
483
+ def read_duration_from_textgrid(
484
+ textgrid_path: Union[str, Path],
485
+ phoneme: str,
486
+ utterance_duration: float,
487
+ ):
488
+ ph_list = phoneme.split(" ")
489
+ with open(textgrid_path, "r") as f:
490
+ textgrid = f.readlines()
491
+ textgrid = remove_empty_lines(textgrid)
492
+ textgrid = TextGrid(textgrid)
493
+ textgrid = json.loads(textgrid.toJson())
494
+
495
+ split = np.ones(len(ph_list) + 1, np.float32) * -1
496
+ tg_idx = 0
497
+ ph_idx = 0
498
+ tg_align = [x for x in textgrid['tiers'][-1]['items']]
499
+ tg_align_ = []
500
+ for x in tg_align:
501
+ x['xmin'] = float(x['xmin'])
502
+ x['xmax'] = float(x['xmax'])
503
+ if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC', '<SP>', '<AP>']:
504
+ x['text'] = ''
505
+ if len(tg_align_) > 0 and tg_align_[-1]['text'] == '':
506
+ tg_align_[-1]['xmax'] = x['xmax']
507
+ continue
508
+ tg_align_.append(x)
509
+ tg_align = tg_align_
510
+ tg_len = len([x for x in tg_align if x['text'] != ''])
511
+ ph_len = len([x for x in ph_list if not is_sil_phoneme(x)])
512
+ assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, textgrid_path)
513
+ while tg_idx < len(tg_align) or ph_idx < len(ph_list):
514
+ if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]):
515
+ split[ph_idx] = 1e8
516
+ ph_idx += 1
517
+ continue
518
+ x = tg_align[tg_idx]
519
+ if x['text'] == '' and ph_idx == len(ph_list):
520
+ tg_idx += 1
521
+ continue
522
+ assert ph_idx < len(ph_list), (
523
+ tg_len, ph_len, tg_align, ph_list, textgrid_path
524
+ )
525
+
526
+ ph = ph_list[ph_idx]
527
+ if x['text'] == '' and not is_sil_phoneme(ph):
528
+ assert False, (ph_list, tg_align)
529
+ if x['text'] != '' and is_sil_phoneme(ph):
530
+ ph_idx += 1
531
+ else:
532
+ assert (x['text'] == '' and is_sil_phoneme(ph)) \
533
+ or x['text'].lower() == ph.lower() \
534
+ or x['text'].lower() == 'sil', (x['text'], ph)
535
+ split[ph_idx] = x['xmin']
536
+ if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme(
537
+ ph_list[ph_idx - 1]
538
+ ):
539
+ split[ph_idx - 1] = split[ph_idx]
540
+ ph_idx += 1
541
+ tg_idx += 1
542
+ assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align])
543
+ assert ph_idx >= len(ph_list) - 1, (
544
+ ph_idx, ph_list, len(ph_list), [x['text']
545
+ for x in tg_align], textgrid_path
546
+ )
547
+
548
+ split[0] = 0
549
+ split[-1] = utterance_duration
550
+ duration = np.diff(split)
551
+ return duration
552
+
553
+
554
+ class SVSInputConverter:
555
+ def build_pinyin_ph_mapping(self, pinyin2ph: str):
556
+ pinyin2phs = {'AP': '<AP>', 'SP': '<SP>'}
557
+ with open(pinyin2ph) as rf:
558
+ for line in rf.readlines():
559
+ elements = [
560
+ x.strip() for x in line.split('|') if x.strip() != ''
561
+ ]
562
+ pinyin2phs[elements[0]] = elements[1]
563
+ return pinyin2phs
564
+
565
+ def __init__(self, singer_map: dict, pinyin2ph: str):
566
+ self.pinyin2phs = self.build_pinyin_ph_mapping(pinyin2ph)
567
+ self.spk_map = singer_map
568
+
569
+ def preprocess_word_level_input(self, inp):
570
+ # Pypinyin can't solve polyphonic words
571
+ text_raw = inp['text']
572
+
573
+ # lyric
574
+ pinyins = lazy_pinyin(text_raw, strict=False)
575
+ ph_per_word_lst = [
576
+ self.pinyin2phs[pinyin.strip()]
577
+ for pinyin in pinyins if pinyin.strip() in self.pinyin2phs
578
+ ]
579
+
580
+ # Note
581
+ note_per_word_lst = [
582
+ x.strip() for x in inp['notes'].split('|') if x.strip() != ''
583
+ ]
584
+ mididur_per_word_lst = [
585
+ x.strip()
586
+ for x in inp['notes_duration'].split('|') if x.strip() != ''
587
+ ]
588
+
589
+ if len(note_per_word_lst) == len(ph_per_word_lst
590
+ ) == len(mididur_per_word_lst):
591
+ print('Pass word-notes check.')
592
+ else:
593
+ print(
594
+ 'The number of words doesn\'t match the number of notes\' windows. ',
595
+ 'You should split the note(s) for each word by | mark.'
596
+ )
597
+ print(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst)
598
+ print(
599
+ len(ph_per_word_lst), len(note_per_word_lst),
600
+ len(mididur_per_word_lst)
601
+ )
602
+ return None
603
+
604
+ note_lst = []
605
+ ph_lst = []
606
+ midi_dur_lst = []
607
+ is_slur = []
608
+ for idx, ph_per_word in enumerate(ph_per_word_lst):
609
+ # for phs in one word:
610
+ # single ph like ['ai'] or multiple phs like ['n', 'i']
611
+ ph_in_this_word = ph_per_word.split()
612
+
613
+ # for notes in one word:
614
+ # single note like ['D4'] or multiple notes like ['D4', 'E4'] which means a 'slur' here.
615
+ note_in_this_word = note_per_word_lst[idx].split()
616
+ midi_dur_in_this_word = mididur_per_word_lst[idx].split()
617
+ # process for the model input
618
+ # Step 1.
619
+ # Deal with note of 'not slur' case or the first note of 'slur' case
620
+ # j ie
621
+ # F#4/Gb4 F#4/Gb4
622
+ # 0 0
623
+ for ph in ph_in_this_word:
624
+ ph_lst.append(ph)
625
+ note_lst.append(note_in_this_word[0])
626
+ midi_dur_lst.append(midi_dur_in_this_word[0])
627
+ is_slur.append(0)
628
+ # step 2.
629
+ # Deal with the 2nd, 3rd... notes of 'slur' case
630
+ # j ie ie
631
+ # F#4/Gb4 F#4/Gb4 C#4/Db4
632
+ # 0 0 1
633
+ if len(
634
+ note_in_this_word
635
+ ) > 1: # is_slur = True, we should repeat the YUNMU to match the 2nd, 3rd... notes.
636
+ for idx in range(1, len(note_in_this_word)):
637
+ ph_lst.append(ph_in_this_word[-1])
638
+ note_lst.append(note_in_this_word[idx])
639
+ midi_dur_lst.append(midi_dur_in_this_word[idx])
640
+ is_slur.append(1)
641
+ ph_seq = ' '.join(ph_lst)
642
+
643
+ if len(ph_lst) == len(note_lst) == len(midi_dur_lst):
644
+ print(len(ph_lst), len(note_lst), len(midi_dur_lst))
645
+ print('Pass word-notes check.')
646
+ else:
647
+ print(
648
+ 'The number of words doesn\'t match the number of notes\' windows. ',
649
+ 'You should split the note(s) for each word by | mark.'
650
+ )
651
+ return None
652
+ return ph_seq, note_lst, midi_dur_lst, is_slur
653
+
654
+ def preprocess_phoneme_level_input(self, inp):
655
+ ph_seq = inp['ph_seq']
656
+ note_lst = inp['note_seq'].split()
657
+ midi_dur_lst = inp['note_dur_seq'].split()
658
+ is_slur = [float(x) for x in inp['is_slur_seq'].split()]
659
+ print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
660
+ if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst):
661
+ print('Pass word-notes check.')
662
+ else:
663
+ print(
664
+ 'The number of words doesn\'t match the number of notes\' windows. ',
665
+ 'You should split the note(s) for each word by | mark.'
666
+ )
667
+ return None
668
+ return ph_seq, note_lst, midi_dur_lst, is_slur
669
+
670
+ def preprocess_input(self, inp, input_type='word'):
671
+ """
672
+
673
+ :param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)}
674
+ :return:
675
+ """
676
+
677
+ # item_name = inp.get('item_name', '<ITEM_NAME>')
678
+ spk_name = inp.get('spk_name', 'Alto-1')
679
+
680
+ # single spk
681
+ spk_id = self.spk_map[spk_name]
682
+
683
+ # get ph seq, note lst, midi dur lst, is slur lst.
684
+ if input_type == 'word':
685
+ ret = self.preprocess_word_level_input(inp)
686
+ elif input_type == 'phoneme':
687
+ ret = self.preprocess_phoneme_level_input(inp)
688
+ else:
689
+ print('Invalid input type.')
690
+ return None
691
+
692
+ if ret:
693
+ ph_seq, note_lst, midi_dur_lst, is_slur = ret
694
+ else:
695
+ print(
696
+ '==========> Preprocess_word_level or phone_level input wrong.'
697
+ )
698
+ return None
699
+
700
+ # convert note lst to midi id; convert note dur lst to midi duration
701
+ try:
702
+ midis = [
703
+ librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0
704
+ for x in note_lst
705
+ ]
706
+ midi_dur_lst = [float(x) for x in midi_dur_lst]
707
+ except Exception as e:
708
+ print(e)
709
+ print('Invalid Input Type.')
710
+ return None
711
+
712
+ # ph_token = self.ph_encoder.encode(ph_seq)
713
+ item = {
714
+ # 'text': inp['text'],
715
+ 'phoneme': ph_seq,
716
+ 'spk': spk_id,
717
+ 'midi': np.asarray(midis),
718
+ 'midi_duration': np.asarray(midi_dur_lst),
719
+ 'is_slur': np.asarray(is_slur),
720
+ }
721
+ return item
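
Two of the helpers above are easiest to understand on a toy F0 contour: `f0_to_coarse` quantises frequencies into 256 mel-spaced bins (unvoiced frames end up in bin 1), and `norm_interp_f0` log-normalises the contour while interpolating across unvoiced gaps and returning the voiced/unvoiced mask. A minimal sketch, with a hypothetical contour; `np.log2(0)` emits a harmless warning, just as in the code above:

import numpy as np
from utils.diffsinger_utilities import f0_to_coarse, norm_interp_f0

f0 = np.array([0.0, 220.0, 230.0, 0.0, 0.0, 240.0, 250.0, 0.0])
coarse = f0_to_coarse(f0)
f0_norm, uv = norm_interp_f0(f0, f0_mean=0.0, f0_std=1.0, pitch_norm="log")
print(coarse)   # small integer bins; the zero (unvoiced) frames map to bin 1
print(uv)       # 1.0 wherever the original contour was unvoiced
print(f0_norm)  # log2 pitch with the unvoiced gaps filled by interpolation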
utils/general.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from typing import Union, Dict
4
+ from pathlib import Path
5
+ import os
6
+
7
+ from pathlib import Path
8
+
9
+ MAX_FILE_NAME_LENGTH = 100
10
+ TASK2DATASET_CLASS = {
11
+ 't2a': "TextToAudioDataset",
12
+ 't2m': "TextToMusicDataset",
13
+ 'se': "SpeechEnhancementDataset",
14
+ 'sr': "AudioSuperResolutionDataset",
15
+ 'v2a': "VideoToAudioDataset",
16
+ 'svs': "MidiSingingDataset",
17
+ 'tts': "TextToSpeechDataset"
18
+ }
19
+
20
+
21
+ def read_jsonl_to_mapping(
22
+ jsonl_file: Union[str, Path],
23
+ key_col: str,
24
+ value_col: str,
25
+ base_path=None
26
+ ) -> Dict[str, str]:
27
+ """
28
+ Read two columns, indicated by `key_col` and `value_col`, from the
29
+ given jsonl file to return the mapping dict
30
+ TODO handle duplicate keys
31
+ """
32
+ mapping = {}
33
+ with open(jsonl_file, 'r') as file:
34
+ for line in file.readlines():
35
+ data = json.loads(line.strip())
36
+ key = data[key_col]
37
+ value = data[value_col]
38
+ if base_path:
39
+ value = os.path.join(base_path, value)
40
+ mapping[key] = value
41
+ return mapping
42
+
43
+
44
+ def sanitize_filename(name: str, max_len: int = MAX_FILE_NAME_LENGTH) -> str:
45
+ """
46
+ Clean and truncate a string to make it a valid and safe filename.
47
+ """
48
+ name = re.sub(r'[\\/*?:"<>|]', '_', name)
49
+ name = name.replace('/', '_')
50
+ max_len = min(len(name), max_len)
51
+ return name[:max_len]
52
+
53
+
54
+ def transform_gen_fn_to_id(audio_file: Path, task: str) -> str:
55
+ if task == "svs":
56
+ audio_id = audio_file.stem.split("_")[0]
57
+ elif task == "sr":
58
+ audio_id = audio_file.stem
59
+ elif task == "tta":
60
+ audio_id = audio_file.stem[:11]
61
+ # audio_id = audio_file.stem[:12] + '.wav'
62
+ elif task == "ttm":
63
+ audio_id = audio_file.stem[:11]
64
+ # audio_id = audio_file.stem[:12] + '.wav'
65
+ elif task == "v2a":
66
+ audio_id = audio_file.stem.rsplit("_", 1)[0] + ".mp4"
67
+ else:
68
+ audio_id = audio_file.stem
69
+ return audio_id
70
+
71
+
72
+ def audio_dir_to_mapping(audio_dir: str | Path, task: str) -> dict:
73
+ mapping = {}
74
+ audio_dir = Path(audio_dir)
75
+ audio_files = sorted(audio_dir.iterdir())
76
+ for audio_file in audio_files:
77
+ audio_id = transform_gen_fn_to_id(audio_file, task)
78
+ mapping[audio_id] = str(audio_file.resolve())
79
+ return mapping
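
`read_jsonl_to_mapping` is the glue between JSONL manifests and the datasets: it picks two columns and optionally prefixes the value with a base path. A sketch using a throw-away manifest (file contents, column names and the base path are hypothetical):

import json
import tempfile
from utils.general import read_jsonl_to_mapping, sanitize_filename

rows = [
    {"audio_id": "clip_001", "audio": "wav/clip_001.wav"},
    {"audio_id": "clip_002", "audio": "wav/clip_002.wav"},
]
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    f.write("\n".join(json.dumps(row) for row in rows))

mapping = read_jsonl_to_mapping(
    f.name, key_col="audio_id", value_col="audio", base_path="/data"
)
print(mapping["clip_001"])                       # /data/wav/clip_001.wav
print(sanitize_filename("a dog barks / rain?"))  # unsafe characters become "_"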
utils/logging.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from dataclasses import dataclass
3
+ import logging
4
+
5
+
6
+ @dataclass
7
+ class LoggingLogger:
8
+
9
+ filename: str | Path
10
+ level: str = "INFO"
11
+
12
+ def create_instance(self, ):
13
+ filename = self.filename.__str__()
14
+ formatter = logging.Formatter("[%(asctime)s] - %(message)s")
15
+
16
+ logger = logging.getLogger(__name__ + "." + filename)
17
+ logger.setLevel(getattr(logging, self.level))
18
+
19
+ file_handler = logging.FileHandler(filename)
20
+ file_handler.setFormatter(formatter)
21
+ logger.addHandler(file_handler)
22
+
23
+ return logger
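
`LoggingLogger` is a small dataclass wrapper so a file logger can be declared in config first and only instantiated when needed. A minimal usage sketch (the file name is hypothetical):

from utils.logging import LoggingLogger

logger = LoggingLogger(filename="train.log", level="INFO").create_instance()
logger.info("training started")  # appended to train.log with a timestamp prefix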
utils/lr_scheduler_utilities.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import math
3
+ import copy
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
+ def get_warmup_steps(
8
+ dataloader_one_pass_outside_steps: int,
9
+ warmup_steps: int | None = None,
10
+ warmup_epochs: float | None = None,
11
+ epoch_length: int | None = None,
12
+ ) -> int:
13
+ """
14
+ Derive warmup steps according to step number or epoch number.
15
+ If `warmup_steps` is provided, then just return it. Otherwise, derive
16
+ the warmup steps by epoch length and warmup epoch number.
17
+ """
18
+ if warmup_steps is not None:
19
+ return warmup_steps
20
+ else:
21
+ if epoch_length is None:
22
+ epoch_length = dataloader_one_pass_outside_steps
23
+ assert warmup_epochs is not None, "warmup_steps and warmup_epochs cannot both be None"
24
+ return int(epoch_length * warmup_epochs)
25
+
26
+
27
+ def get_dataloader_one_pass_outside_steps(
28
+ train_dataloader: DataLoader,
29
+ num_processes: int = 1,
30
+ ):
31
+ """
32
+ dataloader length after DDP, close to `original_length / gpu_number`
33
+ """
34
+ return math.ceil(len(train_dataloader) / num_processes)
35
+
36
+
37
+ def get_total_training_steps(
38
+ train_dataloader: DataLoader,
39
+ epochs: int,
40
+ num_processes: int = 1,
41
+ epoch_length: int | None = None
42
+ ):
43
+ """
44
+ Calculate the total number of "visible" training steps.
45
+
46
+ If `epoch_length` is provided, it is used as the fixed length for each epoch.
47
+ Otherwise, the function will determine the epoch length from `train_dataloader`.
48
+
49
+ Args:
50
+ train_dataloader:
51
+ Training dataloader object.
52
+ epochs:
53
+ The total number of epochs to run.
54
+ num_processes:
55
+ The number of parallel processes used for distributed training.
56
+ epoch_length:
57
+ A fixed number of training steps for each epoch. Defaults to None.
58
+
59
+ Returns:
60
+ int: The total number of training steps (i.e., `epochs * epoch_length`).
61
+ """
62
+ # `epoch_length` is not None: fixed length for each epoch
63
+ if epoch_length is None:
64
+ # `epoch_length` is the length of DDP-wrapped `train_dataloader`
65
+ epoch_length = get_dataloader_one_pass_outside_steps(
66
+ train_dataloader, num_processes
67
+ )
68
+ return epochs * epoch_length
69
+
70
+
71
+ def get_dataloader_one_pass_steps_inside_accelerator(
72
+ dataloader_one_pass_steps: int, gradient_accumulation_steps: int,
73
+ num_processes: int
74
+ ):
75
+ """
76
+ Calculate the number of "visible" training steps for a single pass over the dataloader
77
+ inside an accelerator, accounting for gradient accumulation and distributed training.
78
+
79
+
80
+ Args:
81
+ dataloader_one_pass_steps:
82
+ The number of steps (batches) in one pass over the dataset.
83
+ gradient_accumulation_steps:
84
+ The number of steps to accumulate gradients before performing a parameter update.
85
+ num_processes:
86
+ The number of parallel processes used for distributed training.
87
+
88
+ Returns:
89
+ int: The total number of "visible" training steps for one pass over the dataset,
90
+ multiplied by the number of processes.
91
+ """
92
+ return math.ceil(
93
+ dataloader_one_pass_steps / gradient_accumulation_steps
94
+ ) * num_processes
95
+
96
+
97
+ def get_steps_inside_accelerator_from_outside_steps(
98
+ outside_steps: int, dataloader_one_pass_outside_steps: int,
99
+ dataloader_one_pass_steps_inside_accelerator: int,
100
+ gradient_accumulation_steps: int, num_processes: int
101
+ ):
102
+ """
103
+ Convert "outside" steps (as observed in wandb logger or similar context)
104
+ to the corresponding number of "inside" steps (for accelerate lr scheduler).
105
+
106
+ Specifically, accelerate lr scheduler call `step()` `num_processes` times for
107
+ every `gradient_accumulation_steps` outside steps.
108
+
109
+ Args:
110
+ outside_steps:
111
+ The total number of steps counted outside accelerate context.
112
+ dataloader_one_pass_outside_steps:
113
+ The number of steps (batches) to complete one pass of the dataloader
114
+ outside accelerate.
115
+ dataloader_one_pass_steps_inside_accelerator:
116
+ The number of `lr_scheduler.step()` calls inside accelerate, calculated via
117
+ `get_dataloader_one_pass_steps_inside_accelerator`.
118
+ gradient_accumulation_steps:
119
+ The number of steps to accumulate gradients.
120
+ num_processes:
121
+ The number of parallel processes (GPUs) used in distributed training.
122
+
123
+ Returns:
124
+ int: The total number of `lr_scheduler.step()` calls inside accelerate that
125
+ correspond to the given `outside_steps`.
126
+ """
127
+ num_dataloader_epochs_passed = outside_steps // dataloader_one_pass_outside_steps
128
+ remaining_outside_steps = outside_steps % dataloader_one_pass_outside_steps
129
+ remaining_inside_accelerator_steps = (
130
+ remaining_outside_steps // gradient_accumulation_steps * num_processes
131
+ )
132
+ # accelerate scheduler call `step()` `num_processes` times every
133
+ # `gradient_accumulation_steps` steps:
134
+ # https://github.com/huggingface/accelerate/blob/main/src/accelerate/scheduler.py#L76
135
+ total_steps = (
136
+ num_dataloader_epochs_passed*
137
+ dataloader_one_pass_steps_inside_accelerator +
138
+ remaining_inside_accelerator_steps
139
+ )
140
+ return total_steps
141
+
142
+
143
+ def lr_scheduler_param_adapter(
144
+ config_dict: dict[str, Any], num_training_steps: int, num_warmup_steps: int
145
+ ) -> dict[str, Any]:
146
+ target_class = config_dict["_target_"]
147
+ return_dict = copy.deepcopy(config_dict)
148
+ if target_class == "transformers.get_scheduler":
149
+ return_dict.update({
150
+ "num_training_steps": num_training_steps,
151
+ "num_warmup_steps": num_warmup_steps
152
+ })
153
+
154
+ return return_dict
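
The step-conversion helpers are easiest to follow with concrete numbers. A worked sketch under hypothetical settings (1000 batches per dataloader pass, 4 processes, gradient accumulation of 2, resuming from outside step 2500):

from utils.lr_scheduler_utilities import (
    get_dataloader_one_pass_steps_inside_accelerator,
    get_steps_inside_accelerator_from_outside_steps,
)

one_pass_outside = 1000
inside_per_pass = get_dataloader_one_pass_steps_inside_accelerator(
    dataloader_one_pass_steps=one_pass_outside,
    gradient_accumulation_steps=2,
    num_processes=4,
)
# ceil(1000 / 2) * 4 = 2000 scheduler steps per dataloader pass

inside_steps = get_steps_inside_accelerator_from_outside_steps(
    outside_steps=2500,
    dataloader_one_pass_outside_steps=one_pass_outside,
    dataloader_one_pass_steps_inside_accelerator=inside_per_pass,
    gradient_accumulation_steps=2,
    num_processes=4,
)
# 2 full passes (2 * 2000) plus (500 // 2) * 4 = 1000 remaining steps -> 5000
print(inside_per_pass, inside_steps)  # 2000 5000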
utils/torch_utilities.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ from typing import Callable
4
+ from pathlib import Path
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ logger = logging.Logger(__file__)
10
+
11
+
12
+ def remove_key_prefix_factory(prefix: str = "module."):
13
+ def func(
14
+ model_dict: dict[str, torch.Tensor], state_dict: dict[str,
15
+ torch.Tensor]
16
+ ) -> dict[str, torch.Tensor]:
17
+
18
+ state_dict = {
19
+ key[len(prefix):]: value
20
+ for key, value in state_dict.items() if key.startswith(prefix)
21
+ }
22
+ return state_dict
23
+
24
+ return func
25
+
26
+
27
+ def merge_matched_keys(
28
+ model_dict: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor]
29
+ ) -> dict[str, torch.Tensor]:
30
+ """
31
+ Args:
32
+ model_dict:
33
+ The state dict of the current model, which is going to load pretrained parameters
34
+ state_dict:
35
+ A dictionary of parameters from a pre-trained model.
36
+
37
+ Returns:
38
+ dict[str, torch.Tensor]:
39
+ The updated state dict, where parameters with matched keys and shape are
40
+ updated with values in `state_dict`.
41
+ """
42
+ pretrained_dict = {}
43
+ mismatch_keys = []
44
+ for key, value in state_dict.items():
45
+ if key in model_dict and model_dict[key].shape == value.shape:
46
+ pretrained_dict[key] = value
47
+ else:
48
+ mismatch_keys.append(key)
49
+ logger.info(
50
+ f"Loading pre-trained model, with mismatched keys {mismatch_keys}"
51
+ )
52
+ model_dict.update(pretrained_dict)
53
+ return model_dict
54
+
55
+
56
+ def load_pretrained_model(
57
+ model: nn.Module,
58
+ ckpt_or_state_dict: str | Path | dict[str, torch.Tensor],
59
+ state_dict_process_fn: Callable = merge_matched_keys
60
+ ) -> None:
61
+ state_dict = ckpt_or_state_dict
62
+ if not isinstance(state_dict, dict):
63
+ state_dict = torch.load(ckpt_or_state_dict, "cpu")
64
+
65
+ model_dict = model.state_dict()
66
+ state_dict = state_dict_process_fn(model_dict, state_dict)
67
+ model.load_state_dict(state_dict)
68
+
69
+
70
+ def create_mask_from_length(
71
+ lengths: torch.Tensor, max_length: int | None = None
72
+ ):
73
+ if max_length is None:
74
+ max_length = max(lengths)
75
+ idxs = torch.arange(max_length).reshape(1, -1) # (1, max_length)
76
+ mask = idxs.to(lengths.device) < lengths.view(-1, 1)
77
+ # (1, max_length) < (batch_size, 1) -> (batch_size, max_length)
78
+ return mask
79
+
80
+
81
+ def loss_with_mask(
82
+ loss: torch.Tensor,
83
+ mask: torch.Tensor,
84
+ reduce: bool = True
85
+ ) -> torch.Tensor:
86
+ """
87
+ Apply a mask to the loss tensor and optionally reduce it.
88
+
89
+ Args:
90
+ loss: Tensor of shape (b, t, ...) representing the loss values.
91
+ mask: Tensor of shape (b, t) where 1 indicates valid positions and 0 indicates masked positions.
92
+ reduce: If True, return a single scalar value; otherwise, return a tensor of shape (b,).
93
+
94
+ Returns:
95
+ torch.Tensor: A scalar if reduce is True, otherwise a tensor of shape (b,).
96
+ """
97
+ expanded_mask = mask[(..., ) + (None, ) * (loss.ndim - mask.ndim)]
98
+ expanded_mask = expanded_mask.expand_as(loss)
99
+ masked_loss = loss * expanded_mask
100
+
101
+ sum_dims = tuple(range(1, loss.ndim))
102
+ loss_sum = masked_loss.sum(dim=sum_dims)
103
+ mask_sum = expanded_mask.sum(dim=sum_dims)
104
+ loss = loss_sum / mask_sum
105
+
106
+ if reduce:
107
+ return loss.mean()
108
+ else:
109
+ return loss
110
+
111
+
112
+ def convert_pad_shape(pad_shape: list[list[int]]):
113
+ l = pad_shape[::-1]
114
+ pad_shape = [item for sublist in l for item in sublist]
115
+ return pad_shape
116
+
117
+
118
+ def create_alignment_path(duration: torch.Tensor, mask: torch.Tensor):
119
+ device = duration.device
120
+
121
+ b, t_x, t_y = mask.shape
122
+ cum_duration = torch.cumsum(duration, 1)
123
+
124
+ cum_duration_flat = cum_duration.view(b * t_x)
125
+ path = create_mask_from_length(cum_duration_flat, t_y).float()
126
+ path = path.view(b, t_x, t_y)
127
+ # take the diff on the `t_x` axis
128
+ path = path - torch.nn.functional.pad(
129
+ path, convert_pad_shape([[0, 0], [1, 0], [0, 0]])
130
+ )[:, :-1]
131
+ path = path * mask
132
+ return path
133
+
134
+
135
+ def trim_or_pad_length(x: torch.Tensor, target_length: int, length_dim: int):
136
+ """
137
+ Adjusts the size of the specified dimension of tensor x to match `target_length`.
138
+
139
+ Args:
140
+ x:
141
+ Input tensor.
142
+ target_length:
143
+ Desired size of the specified dimension.
144
+ length_dim:
145
+ The dimension to modify.
146
+
147
+ Returns:
148
+ torch.Tensor: The adjusted tensor.
149
+ """
150
+ current_length = x.shape[length_dim]
151
+
152
+ if current_length > target_length:
153
+ # Truncate the tensor
154
+ slices = [slice(None)] * x.ndim
155
+ slices[length_dim] = slice(0, target_length)
156
+ return x[tuple(slices)]
157
+
158
+ elif current_length < target_length:
159
+ # Pad the tensor with zeros
160
+ pad_shape = list(x.shape)
161
+ pad_length = target_length - current_length
162
+
163
+ pad_shape[length_dim] = pad_length # Shape for left padding
164
+ padding = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
165
+
166
+ return torch.cat([x, padding], dim=length_dim)
167
+
168
+ return x
169
+
170
+
171
+ def concat_non_padding(
172
+ seq1: torch.Tensor, mask1: torch.BoolTensor, seq2: torch.Tensor,
173
+ mask2: torch.BoolTensor
174
+ ):
175
+ """
176
+ Args
177
+ seq1 : Tensor (B, L1, E)
178
+ First sequence.
179
+ mask1 : BoolTensor (B, L1)
180
+ True for valid tokens in seq1, False for padding.
181
+ seq2 : Tensor (B, L2, E)
182
+ Second sequence.
183
+ mask2 : BoolTensor (B, L2)
184
+ True for valid tokens in seq2, False for padding.
185
+
186
+ Returns
187
+ concat_seq : Tensor (B, L1+L2, E)
188
+ Both sequences concatenated; valid tokens are left-aligned,
189
+ padding on the right is 0.
190
+ concat_mask: BoolTensor (B, L1+L2)
191
+ Mask for the concatenated sequence.
192
+ perm : LongTensor (B, L1+L2)
193
+ Permutation that maps **original indices → new indices**.
194
+ Needed for restoring the original sequences.
195
+ """
196
+ mask1, mask2 = mask1.bool(), mask2.bool()
197
+ B, L1, E = seq1.shape
198
+ L2 = seq2.size(1)
199
+ L = L1 + L2
200
+
201
+ seq_cat = torch.cat([seq1, seq2], dim=1) # (B, L, E)
202
+ mask_cat = torch.cat([mask1, mask2], dim=1) # (B, L)
203
+
204
+ # ----- Key step: stable sort so that all valid tokens move to the left -----
205
+ # Padding positions get +L, guaranteeing the largest “score” → sorted to the end.
206
+ positions = torch.arange(L, device=seq_cat.device).unsqueeze(0) # (1, L)
207
+ sort_score = positions + (~mask_cat) * L
208
+ perm = sort_score.argsort(dim=1, stable=True) # (B, L)
209
+
210
+ # Build concatenated sequence & mask
211
+ gather_idx = perm.unsqueeze(-1).expand(-1, -1, E) # (B, L, E)
212
+ concat_seq = seq_cat.gather(1, gather_idx)
213
+ concat_mask = mask_cat.gather(1, perm)
214
+
215
+ # Explicitly zero out the right-hand padding region for safety
216
+ concat_seq = concat_seq * concat_mask.unsqueeze(-1)
217
+
218
+ return concat_seq, concat_mask, perm
219
+
220
+
221
+ def restore_from_concat(
+     concat_seq: torch.Tensor, mask1: torch.BoolTensor, mask2: torch.BoolTensor,
+     perm: torch.LongTensor
+ ):
+     """
+     Restore (seq1, seq2) from the concatenated sequence produced by
+     `concat_non_padding`, using the returned permutation `perm`.
+     Fully vectorised; no Python loops.
+     """
+     mask1, mask2 = mask1.bool(), mask2.bool()
+     B, L1 = mask1.shape
+     L2 = mask2.size(1)
+     E = concat_seq.size(-1)
+
+     # Inverse permutation: inv_perm[b, j] gives the position in concat_seq
+     # where the token originally at index j ended up.
+     inv_perm = torch.empty_like(perm)
+     inv_perm.scatter_(
+         1, perm,
+         torch.arange(L1 + L2, device=perm.device).unsqueeze(0).expand(B, -1)
+     )
+
+     # Bring tokens back to their original order
+     gather_idx = inv_perm.unsqueeze(-1).expand(-1, -1, E)
+     seq_cat_rec = concat_seq.gather(1, gather_idx)  # (B, L1+L2, E)
+
+     # Split back into the two sequences and mask out padding positions
+     seq1_restore, seq2_restore = seq_cat_rec.split([L1, L2], dim=1)
+     seq1_restore = seq1_restore * mask1.unsqueeze(-1)
+     seq2_restore = seq2_restore * mask2.unsqueeze(-1)
+
+     return seq1_restore, seq2_restore
+
+
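A round-trip sketch (not part of the commit) showing that the two helpers invert each other up to the padding positions, which are zeroed:

    import torch

    seq1 = torch.randn(2, 4, 8)
    mask1 = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]).bool()
    seq2 = torch.randn(2, 3, 8)
    mask2 = torch.tensor([[1, 1, 0], [1, 1, 1]]).bool()

    concat_seq, concat_mask, perm = concat_non_padding(seq1, mask1, seq2, mask2)
    seq1_rec, seq2_rec = restore_from_concat(concat_seq, mask1, mask2, perm)

    assert torch.allclose(seq1_rec, seq1 * mask1.unsqueeze(-1))
    assert torch.allclose(seq2_rec, seq2 * mask2.unsqueeze(-1))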
+ def contains_nan(data):
+     """Check if data contains NaN."""
+     if isinstance(data, torch.Tensor):
+         return torch.isnan(data).any().item()
+     elif isinstance(data, np.ndarray):
+         return np.isnan(data).any()
+     elif isinstance(data, float):
+         return math.isnan(data)
+     elif isinstance(data, (list, tuple)):
+         return any(contains_nan(x) for x in data)
+     elif isinstance(data, dict):
+         return any(contains_nan(v) for v in data.values())
+     return False
+
+
+ def check_nan_in_batch(batch):
+     """Check if batch contains NaN and return the ids of audio samples containing NaN."""
+     assert type(batch) == dict, "batch type error"
+     nan_audio_ids = []
+     audio_ids = batch["audio_id"]
+     audio_id2content = {}
+     for idx, audio_id in enumerate(audio_ids):
+         content = []
+         for k, v in batch.items():
+             if k == "audio_id":
+                 continue
+             content.append(v[idx])
+         audio_id2content[audio_id] = content
+
+     for audio_id, content in audio_id2content.items():
+         if contains_nan(content):
+             nan_audio_ids.append(audio_id)
+             print(f"{audio_id} contains NaN")
+     return nan_audio_ids
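An illustrative sketch (not part of the commit); the "waveform" key is hypothetical and stands for any per-sample field of a collated batch:

    import torch

    batch = {
        "audio_id": ["clip_0", "clip_1"],
        "waveform": torch.tensor([[0.1, 0.2], [float("nan"), 0.3]]),
    }
    bad_ids = check_nan_in_batch(batch)   # prints "clip_1 contains NaN"
    # bad_ids == ["clip_1"]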
utils/video.py ADDED
@@ -0,0 +1,104 @@
+ from pathlib import Path
+ import os
+ from typing import Callable
+ import tempfile
+
+ import numpy as np
+ import soundfile as sf
+ from moviepy import VideoFileClip, AudioFileClip
+ from moviepy.audio.AudioClip import AudioArrayClip
+ from moviepy.audio.fx import AudioLoop
+ import torch
+ import torchvision
+
+
+ def merge_audio_video(
+     audio: str | Path | np.ndarray,
+     video_path: str | Path,
+     target_path: str | Path,
+     backend: str = "moviepy",
+     logging: bool = False,
+     audio_fps: int | None = None
+ ):
+     """
+     Merge audio and video into a single file.
+
+     Args:
+         audio (str | Path | np.ndarray): Path to the audio file, or a waveform
+             array (requires `audio_fps`).
+         video_path (str | Path): Path to the video file.
+         target_path (str | Path): Path to the target file.
+         backend (str, optional): The backend to use for merging. Defaults to "moviepy".
+         logging (bool, optional): Whether to show merging progress. Defaults to False.
+         audio_fps (int, optional): Sampling rate of `audio` when it is an array.
+     """
+     assert backend in [
+         "moviepy", "ffmpeg"
+     ], "Backend should be moviepy or ffmpeg"
+     if backend == "moviepy":
+         video = VideoFileClip(video_path.__str__())
+         video = video.without_audio()
+         if isinstance(audio, np.ndarray):
+             assert audio_fps is not None
+             # write to a temp file, then use AudioFileClip to load
+             with tempfile.NamedTemporaryFile(
+                 suffix=".wav", delete=False
+             ) as tmp_wav:
+                 sf.write(tmp_wav.name, audio, samplerate=audio_fps)
+                 audio = AudioFileClip(tmp_wav.name)
+         else:
+             audio = AudioFileClip(audio.__str__())
+             tmp_wav = None
+
+         video = video.with_audio(audio)
+
+         target_path = Path(target_path)
+         video.write_videofile(
+             target_path,
+             logger=None if not logging else "bar",
+             threads=8,
+             preset="ultrafast",
+             ffmpeg_params=["-crf", "23"]
+         )
+         if tmp_wav:
+             os.remove(tmp_wav.name)
+     else:
+         logging_arg = "" if logging else "-loglevel quiet"
+         command = f"ffmpeg {logging_arg} -i '{video_path.__str__()}' -i '{audio.__str__()}' -c:v copy " \
+             f"-c:a copy -map 0:v:0 -map 1:a:0 '{target_path.__str__()}'"
+         os.system(command)
+
+
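An illustrative usage sketch (not part of the commit); the file names are hypothetical:

    import numpy as np

    # mux a generated 16 kHz mono waveform onto an existing video clip
    waveform = np.random.randn(16000 * 10).astype(np.float32)
    merge_audio_video(
        audio=waveform,
        video_path="example.mp4",
        target_path="example_with_audio.mp4",
        backend="moviepy",
        audio_fps=16000,
    )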
+ def read_video_frames(
+     video_path: str,
+     duration: float | None = 10.0,
+     fps: int = 10,
+     video_size: tuple[int, int] = (256, 256),
+     resize_transform: Callable | None = None,
+ ):
+     """Read a video, pad it with black frames if it is shorter than `duration`
+     seconds, resample it to `fps` frames per second and resize the frames to
+     `video_size`. Returns a [T, C, H, W] tensor; on a read error, an all-zero
+     tensor of the expected shape is returned.
+     """
+     try:
+         video, _, meta = torchvision.io.read_video(
+             str(video_path), start_pts=0, end_pts=duration, pts_unit='sec'
+         )
+         video_duration = video.shape[0] / meta["video_fps"]
+
+         if duration and video_duration < duration:
+             num_frames, height, width, channels = video.shape
+             padding_length = int(duration * meta["video_fps"]) - num_frames
+             padding = torch.zeros((padding_length, height, width, channels),
+                                   dtype=video.dtype)
+             video = torch.cat([video, padding], dim=0)
+             target_length = int(duration * fps)
+         else:
+             target_length = int(video_duration * fps)
+
+         indices = torch.linspace(0, video.shape[0] - 1,
+                                  steps=target_length).long()
+         video = video[indices]
+         video = video.permute(0, 3, 1, 2)  # [T, C, H, W]
+         if resize_transform is None:
+             resize_transform = torchvision.transforms.Resize(video_size)
+         video = resize_transform(video)
+         return video
+     except Exception as e:
+         print(f"error reading video {video_path}: {e}")
+         assert duration is not None
+         target_length = int(duration * fps)
+         return torch.zeros(target_length, 3, *video_size)
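An illustrative usage sketch (not part of the commit); the path is hypothetical:

    frames = read_video_frames(
        "example.mp4", duration=10.0, fps=10, video_size=(256, 256)
    )
    print(frames.shape)   # e.g. torch.Size([100, 3, 256, 256]) for a clip of 10 s or longer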