Stable Audio Open + progbars + mp3 + batched forward + cleanup
- .gitattributes +1 -0
- Examples/{Beethoven.wav → Beethoven.mp3} +2 -2
- Examples/{Cat_dog.wav → Beethoven_arcade.mp3} +2 -2
- Examples/{Beethoven_arcade.wav → Beethoven_piano.mp3} +2 -2
- Examples/{Beethoven_piano.wav → Beethoven_rock.mp3} +2 -2
- Examples/{Cat.wav → Cat.mp3} +2 -2
- Examples/Cat_dog.mp3 +3 -0
- Examples/ModalJazz.mp3 +3 -0
- Examples/ModalJazz.wav +0 -3
- Examples/ModalJazz_banjo.mp3 +3 -0
- Examples/ModalJazz_banjo.wav +0 -3
- Examples/Shadows.mp3 +3 -0
- Examples/Shadows_arcade.mp3 +3 -0
- README.md +4 -1
- app.py +235 -158
- inversion_utils.py +139 -381
- models.py +469 -253
- requirements.txt +3 -2
- utils.py +50 -16
.gitattributes
CHANGED

@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 *.wav filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
Examples/{Beethoven.wav → Beethoven.mp3}
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:3dcc79fe071d118df3caaeeb85d7944f93a5df40bbdb72a26b67bd57da2af7c5
+size 1097142
Examples/{Cat_dog.wav → Beethoven_arcade.mp3}
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:542bd61d9cc1723ccfd9bfc06b0818e77fc763013827ff1f9289e2ac6a912904
+size 563040
Examples/{Beethoven_arcade.wav → Beethoven_piano.mp3}
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:000d82c39d8c41b10188d328e29cb1baa948232bacd693f22e297cc54f4bb707
+size 563040
Examples/{Beethoven_piano.wav → Beethoven_rock.mp3}
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c51d75c9094a50c7892449a013b32ffde266a5abd6dad9f00bf3aeec0ee935ee
+size 1097142
Examples/{Cat.wav → Cat.mp3}
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:cff7010e5fb12a57508c7a0941663f1a12bfc8b3b3d01d0973359cd42ae5eb1e
+size 402542
Examples/Cat_dog.mp3
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72ff727243606215c934552e946f7d97b5e2e39c4d6263f7f36659e3f39f3008
+size 207403
Examples/ModalJazz.mp3
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34cf145b84b6b4669050ca42932fb74ac0f28aabbe6c665f12a877c9809fa9c6
+size 4153468
Examples/ModalJazz.wav
DELETED

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:846a77046d21ebc3996841404eede9d56797c82b3414025e1ccafe586eaf2959
-size 9153322
Examples/ModalJazz_banjo.mp3
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11680068427556981aa6304e6c11bd05debc820ca581c248954c1ffe3cd94569
+size 2128320
Examples/ModalJazz_banjo.wav
DELETED

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:122e0078c0bf2fc96425071706fe0e8674c93cc1d2787fd02c0e2c0f12de5cc5
-size 6802106
Examples/Shadows.mp3
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e0cab2ebda4507641d6a1b5d9b2d888a7526581b7de48540ebf86ce00579908
+size 1342693
Examples/Shadows_arcade.mp3
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68c84805ea17d0697cd79bc85394754d70fb02f740db4bee4c6ccbb5269a5d84
+size 1342693
README.md
CHANGED

@@ -9,7 +9,10 @@ app_file: app.py
 pinned: false
 license: cc-by-sa-4.0
 short_description: Edit audios with text prompts
+hf_oauth: true
+hf_oauth_scopes:
+- read-repos
 ---

 The 30-second limit was introduced to ensure that queue wait times remain reasonable, especially when there are a lot of users.
-For that reason pull-requests that change this limit will not be merged. Please clone or duplicate the space to work locally without limits.
+For that reason pull-requests that change this limit will not be merged. Please clone or duplicate the space to work locally without limits.
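The `hf_oauth: true` / `read-repos` front matter added above turns on Hugging Face login for the Space, which app.py below relies on to reach the gated Stable Audio Open checkpoint. A minimal, self-contained sketch of the mechanism (not the Space's actual code): Gradio renders a login button, and any handler that declares a parameter typed `gr.OAuthToken` receives the logged-in user's token, or `None`. The `whoami` handler and the "Check" button are illustrative names.

from typing import Optional

import gradio as gr
import huggingface_hub


def whoami(oauth_token: Optional[gr.OAuthToken] = None) -> str:
    if oauth_token is None:
        return "Not logged in."
    # The token can be forwarded to huggingface_hub calls, e.g. to check
    # access to a gated model such as Stable Audio Open 1.0.
    user = huggingface_hub.whoami(token=oauth_token.token)
    return f"Logged in as {user['name']}"


with gr.Blocks() as demo:
    gr.LoginButton()  # enabled by `hf_oauth: true` in the README front matter
    status = gr.Textbox(label="Login status")
    gr.Button("Check").click(whoami, inputs=None, outputs=status)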
app.py
CHANGED

@@ -6,27 +6,26 @@ if os.getenv('SPACES_ZERO_GPU') == "true":
 import gradio as gr
 import random
 import torch
+import os
 from torch import inference_mode
-
-from typing import Optional
+from typing import Optional, List
 import numpy as np
 from models import load_model
 import utils
 import spaces
+import huggingface_hub
 from inversion_utils import inversion_forward_process, inversion_reverse_process


-# current_loaded_model = "cvssp/audioldm2-music"
-# # current_loaded_model = "cvssp/audioldm2-music"
-
-# ldm_stable = load_model(current_loaded_model, device, 200)  # deafult model
 LDM2 = "cvssp/audioldm2"
 MUSIC = "cvssp/audioldm2-music"
 LDM2_LARGE = "cvssp/audioldm2-large"
+STABLEAUD = "stabilityai/stable-audio-open-1.0"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 ldm2 = load_model(model_id=LDM2, device=device)
 ldm2_large = load_model(model_id=LDM2_LARGE, device=device)
 ldm2_music = load_model(model_id=MUSIC, device=device)
+ldm_stableaud = load_model(model_id=STABLEAUD, device=device, token=os.getenv('PRIV_TOKEN'))


 def randomize_seed_fn(seed, randomize_seed):

@@ -36,89 +35,136 @@ def randomize_seed_fn(seed, randomize_seed):
     return seed


+def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src, duration, save_compute):
     # ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device)

     with inference_mode():
         w0 = ldm_stable.vae_encode(x0)

     # find Zs and wts - forward process
+    _, zs, wts, extra_info = inversion_forward_process(ldm_stable, w0, etas=1,
+                                                       prompts=[prompt_src],
+                                                       cfg_scales=[cfg_scale_src],
+                                                       num_inference_steps=num_diffusion_steps,
+                                                       numerical_fix=True,
+                                                       duration=duration,
+                                                       save_compute=save_compute)
+    return zs, wts, extra_info


+def sample(ldm_stable, zs, wts, extra_info, prompt_tar, tstart, cfg_scale_tar, duration, save_compute):
     # reverse process (via Zs and wT)
     tstart = torch.tensor(tstart, dtype=torch.int)
+    w0, _ = inversion_reverse_process(ldm_stable, xT=wts, tstart=tstart,
                                       etas=1., prompts=[prompt_tar],
                                       neg_prompts=[""], cfg_scales=[cfg_scale_tar],
+                                      zs=zs[:int(tstart)],
+                                      duration=duration,
+                                      extra_info=extra_info,
+                                      save_compute=save_compute)

     # vae decode image
     with inference_mode():
         x0_dec = ldm_stable.vae_decode(w0)

+    if 'stable-audio' not in ldm_stable.model_id:
+        if x0_dec.dim() < 4:
+            x0_dec = x0_dec[None, :, :, :]

+        with torch.no_grad():
+            audio = ldm_stable.decode_to_mel(x0_dec)
+    else:
+        audio = x0_dec.squeeze(0).T
+
+    return (ldm_stable.get_sr(), audio.squeeze().cpu().numpy())
+
+
+def get_duration(input_audio,
+                 model_id: str,
+                 do_inversion: bool,
+                 wts: Optional[torch.Tensor], zs: Optional[torch.Tensor], extra_info: Optional[List],
+                 saved_inv_model: str,
+                 source_prompt: str = "",
+                 target_prompt: str = "",
+                 steps: int = 200,
+                 cfg_scale_src: float = 3.5,
+                 cfg_scale_tar: float = 12,
+                 t_start: int = 45,
+                 randomize_seed: bool = True,
+                 save_compute: bool = True,
+                 oauth_token: Optional[gr.OAuthToken] = None):
     if model_id == LDM2:
+        factor = 1
     elif model_id == LDM2_LARGE:
+        factor = 2.5
+    elif model_id == STABLEAUD:
+        factor = 3.2
     else:  # MUSIC
         factor = 1

+    forwards = 0
     if do_inversion or randomize_seed:
+        forwards = steps if source_prompt == "" else steps * 2  # x2 when there is a prompt text
+    forwards += int(t_start / 100 * steps) * 2
+
+    duration = min(utils.get_duration(input_audio), utils.MAX_DURATION)
+    time_for_maxlength = factor * forwards * 0.15  # 0.25 is the time per forward pass
+    print('expected time:', time_for_maxlength / utils.MAX_DURATION * duration)
+
+    spare_time = 5
+    return max(10, time_for_maxlength / utils.MAX_DURATION * duration + spare_time)
+

+def verify_model_params(model_id: str, input_audio, src_prompt: str, tar_prompt: str, cfg_scale_src: float,
+                        oauth_token: gr.OAuthToken | None):
     if input_audio is None:
         raise gr.Error('Input audio missing!')

+    if tar_prompt == "":
+        raise gr.Error("Please provide a target prompt to edit the audio.")
+
+    if src_prompt != "":
+        if model_id == STABLEAUD and cfg_scale_src != 1:
+            gr.Info("Consider using Source Guidance Scale=1 for Stable Audio Open 1.0.")
+        elif model_id != STABLEAUD and cfg_scale_src != 3:
+            gr.Info(f"Consider using Source Guidance Scale=3 for {model_id}.")
+
+    if model_id == STABLEAUD:
+        if oauth_token is None:
+            raise gr.Error("You must be logged in to use Stable Audio Open 1.0. Please log in and try again.")
+        try:
+            huggingface_hub.get_hf_file_metadata(huggingface_hub.hf_hub_url(STABLEAUD, 'transformer/config.json'),
+                                                 token=oauth_token.token)
+            print('Has Access')
+        # except huggingface_hub.utils._errors.GatedRepoError:
+        except huggingface_hub.errors.GatedRepoError:
+            raise gr.Error("You need to accept the license agreement to use Stable Audio Open 1.0. "
+                           "Visit the <a href='https://huggingface.co/stabilityai/stable-audio-open-1.0'>"
+                           "model page</a> to get access.")


+@spaces.GPU(duration=get_duration)
+def edit(input_audio,
+         model_id: str,
+         do_inversion: bool,
+         wts: Optional[torch.Tensor], zs: Optional[torch.Tensor], extra_info: Optional[List],
+         saved_inv_model: str,
+         source_prompt: str = "",
+         target_prompt: str = "",
+         steps: int = 200,
+         cfg_scale_src: float = 3.5,
+         cfg_scale_tar: float = 12,
+         t_start: int = 45,
+         randomize_seed: bool = True,
+         save_compute: bool = True,
+         oauth_token: Optional[gr.OAuthToken] = None):
     print(model_id)
     if model_id == LDM2:
         ldm_stable = ldm2
     elif model_id == LDM2_LARGE:
         ldm_stable = ldm2_large
+    elif model_id == STABLEAUD:
+        ldm_stable = ldm_stableaud
     else:  # MUSIC
         ldm_stable = ldm2_music

@@ -130,102 +176,126 @@ def edit(

     if input_audio is None:
         raise gr.Error('Input audio missing!')
+    x0, _, duration = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device,
+                                       stft=('stable-audio' not in ldm_stable.model_id), model_sr=ldm_stable.get_sr())
     if wts is None or zs is None:
         do_inversion = True

     if do_inversion or randomize_seed:  # always re-run inversion
+        zs_tensor, wts_tensor, extra_info_list = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
+                                                        num_diffusion_steps=steps,
+                                                        cfg_scale_src=cfg_scale_src,
+                                                        duration=duration,
+                                                        save_compute=save_compute)
         wts = wts_tensor
         zs = zs_tensor
+        extra_info = extra_info_list
         saved_inv_model = model_id
         do_inversion = False
     else:
         wts_tensor = wts.to(device)
         zs_tensor = zs.to(device)
+        extra_info_list = [e.to(device) for e in extra_info if e is not None]

+    output = sample(ldm_stable, zs_tensor, wts_tensor, extra_info_list, prompt_tar=target_prompt,
+                    tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar, duration=duration,
+                    save_compute=save_compute)

+    return output, wts.cpu(), zs.cpu(), [e.cpu() for e in extra_info if e is not None], saved_inv_model, do_inversion
     # return output, wtszs_file, saved_inv_model, do_inversion


 def get_example():
     case = [
+        ['Examples/Beethoven.mp3',
         '',
         'A recording of an arcade game soundtrack.',
         45,
         'cvssp/audioldm2-music',
         '27s',
+        'Examples/Beethoven_arcade.mp3',
         ],
+        ['Examples/Beethoven.mp3',
         'A high quality recording of wind instruments and strings playing.',
         'A high quality recording of a piano playing.',
         45,
         'cvssp/audioldm2-music',
         '27s',
+        'Examples/Beethoven_piano.mp3',
+        ],
+        ['Examples/Beethoven.mp3',
+        '',
+        'Heavy Rock.',
+        40,
+        'stabilityai/stable-audio-open-1.0',
+        '27s',
+        'Examples/Beethoven_rock.mp3',
         ],
+        ['Examples/ModalJazz.mp3',
         'Trumpets playing alongside a piano, bass and drums in an upbeat old-timey cool jazz song.',
         'A banjo playing alongside a piano, bass and drums in an upbeat old-timey cool country song.',
         45,
         'cvssp/audioldm2-music',
         '106s',
+        'Examples/ModalJazz_banjo.mp3',],
+        ['Examples/Shadows.mp3',
+        '',
+        '8-bit arcade game soundtrack.',
+        40,
+        'stabilityai/stable-audio-open-1.0',
+        '34s',
+        'Examples/Shadows_arcade.mp3',],
+        ['Examples/Cat.mp3',
         '',
         'A dog barking.',
         75,
         'cvssp/audioldm2-large',
         '10s',
+        'Examples/Cat_dog.mp3',]
     ]
     return case


 intro = """
+<h1 style="font-weight: 1000; text-align: center; margin: 0px;"> ZETA Editing 🎧 </h1>
+<h2 style="font-weight: 1000; text-align: center; margin: 0px;">
+Zero-Shot Text-Based Audio Editing Using DDPM Inversion 🎛️ </h2>
+<h3 style="margin-top: 0px; margin-bottom: 10px; text-align: center;">
 <a href="https://arxiv.org/abs/2402.10009">[Paper]</a> |
 <a href="https://hilamanor.github.io/AudioEditing/">[Project page]</a> |
 <a href="https://github.com/HilaManor/AudioEditingCode">[Code]</a>
 </h3>

+<p style="font-size: 1rem; line-height: 1.2em;">
 For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
 <a href="https://huggingface.co/spaces/hilamanor/audioEditing?duplicate=true">
+<img style="margin-top: 0em; margin-bottom: 0em; display:inline" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" >
+</a>
+</p>
+<p style="margin: 0px;">
+<b>NEW - 15.10.24:</b> You can now edit using <b>Stable Audio Open 1.0</b>.
+You must be <b>logged in</b> after accepting the
+<b><a href="https://huggingface.co/stabilityai/stable-audio-open-1.0">license agreement</a></b> to use it.</br>
+</p>
+<ul style="padding-left:40px; line-height:normal;">
+<li style="margin: 0px;">Prompts behave differently - e.g.,
+try "8-bit arcade" directly instead of "a recording of...". Check out the new examples below!</li>
+<li style="margin: 0px;">Try to play around <code>T-start=40%</code>.</li>
+<li style="margin: 0px;">Under "More Options": Use <code>Source Guidance Scale=1</code>,
+and you can try fewer timesteps (even 20!).</li>
+<li style="margin: 0px;">Stable Audio Open is a general-audio model.
+For better music editing, duplicate the space and change to a
+<a href="https://huggingface.co/models?other=base_model:finetune:stabilityai/stable-audio-open-1.0">
+fine-tuned model for music</a>.</li>
+</ul>
+<p>
+<b>NEW - 15.10.24:</b> Parallel editing is enabled by default.
+To disable, uncheck <code>Efficient editing</code> under "More Options".
+Saves a bit of time.
 </p>
 """

+
 help = """
 <div style="font-size:medium">
 <b>Instructions:</b><br>

@@ -233,22 +303,27 @@ help = """
 <li>You must provide an input audio and a target prompt to edit the audio. </li>
 <li>T<sub>start</sub> is used to control the tradeoff between fidelity to the original signal and text-adhearance.
 Lower value -> favor fidelity. Higher value -> apply a stronger edit.</li>
+<li>Make sure that you use a model version that is suitable for your input audio.
+For example, use AudioLDM2-music for music while AudioLDM2-large for general audio.
 </li>
 <li>You can additionally provide a source prompt to guide even further the editing process.</li>
 <li>Longer input will take more time.</li>
 <li><strong>Unlimited length</strong>: This space automatically trims input audio to a maximum length of 30 seconds.
+For unlimited length, duplicated the space, and change the
+<code style="display:inline; background-color: lightgrey;">MAX_DURATION</code> parameter
+inside <code style="display:inline; background-color: lightgrey;">utils.py</code>
+to <code style="display:inline; background-color: lightgrey;">None</code>.
+</li>
 </ul>
 </div>

 """

+css = '.gradio-container {max-width: 1000px !important; padding-top: 1.5rem !important;}' \
+      '.audio-upload .wrap {min-height: 0px;}'
+
+# with gr.Blocks(css='style.css') as demo:
+with gr.Blocks(css=css) as demo:
     def reset_do_inversion(do_inversion_user, do_inversion):
         # do_inversion = gr.State(value=True)
         do_inversion = True

@@ -267,23 +342,22 @@ with gr.Blocks(css='style.css') as demo: #, delete_cache=(3600, 3600)) as demo:
         return do_inversion_user, do_inversion

     gr.HTML(intro)
+
     wts = gr.State()
     zs = gr.State()
+    extra_info = gr.State()
     saved_inv_model = gr.State()
     do_inversion = gr.State(value=True)  # To save some runtime when editing the same thing over and over
     do_inversion_user = gr.State(value=False)

     with gr.Group():
+        gr.Markdown("💡 **note**: input longer than **30 sec** is automatically trimmed "
+                    "(for unlimited input, see the Help section below)")
+        with gr.Row(equal_height=True):
+            input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath",
+                                   editable=True, label="Input Audio", interactive=True, scale=1, format='wav',
+                                   elem_classes=['audio-upload'])
+            output_audio = gr.Audio(label="Edited Audio", interactive=False, scale=1, format='wav')

         with gr.Row():
             tar_prompt = gr.Textbox(label="Prompt", info="Describe your desired edited output",

@@ -293,17 +367,16 @@ with gr.Blocks(css='style.css') as demo: #, delete_cache=(3600, 3600)) as demo:
         with gr.Row():
             t_start = gr.Slider(minimum=15, maximum=85, value=45, step=1, label="T-start (%)", interactive=True, scale=3,
                                 info="Lower T-start -> closer to original audio. Higher T-start -> stronger edit.")
+            model_id = gr.Dropdown(label="Model Version",
+                                   choices=[LDM2,
+                                            LDM2_LARGE,
+                                            MUSIC,
+                                            STABLEAUD],
+                                   info="Choose a checkpoint suitable for your audio and edit",
                                    value="cvssp/audioldm2-music", interactive=True, type="value", scale=2)
         with gr.Row():
+            submit = gr.Button("Edit", variant="primary", scale=3)
+            gr.LoginButton(value="Login to HF (For Stable Audio)", scale=1)

     with gr.Accordion("More Options", open=False):
         with gr.Row():

@@ -311,58 +384,62 @@ with gr.Blocks(css='style.css') as demo: #, delete_cache=(3600, 3600)) as demo:
                                     info="Optional: Describe the original audio input",
                                     placeholder="A recording of a happy upbeat classical music piece",)

+        with gr.Row(equal_height=True):
             cfg_scale_src = gr.Number(value=3, minimum=0.5, maximum=25, precision=None,
                                       label="Source Guidance Scale", interactive=True, scale=1)
             cfg_scale_tar = gr.Number(value=12, minimum=0.5, maximum=25, precision=None,
                                       label="Target Guidance Scale", interactive=True, scale=1)
+            steps = gr.Number(value=50, step=1, minimum=10, maximum=300,
                               info="Higher values (e.g. 200) yield higher-quality generation.",
+                              label="Num Diffusion Steps", interactive=True, scale=2)
+        with gr.Row(equal_height=True):
             seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
             randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
+            save_compute = gr.Checkbox(label='Efficient editing', value=True)
             length = gr.Number(label="Length", interactive=False, visible=False)

     with gr.Accordion("Help💡", open=False):
         gr.HTML(help)

     submit.click(
+        fn=verify_model_params,
+        inputs=[model_id, input_audio, src_prompt, tar_prompt, cfg_scale_src],
+        outputs=[]
+    ).success(
+        fn=randomize_seed_fn, inputs=[seed, randomize_seed], outputs=[seed], queue=False
+    ).then(
+        fn=clear_do_inversion_user, inputs=[do_inversion_user], outputs=[do_inversion_user]
+    ).then(
+        fn=edit,
+        inputs=[input_audio,
+                model_id,
+                do_inversion,
+                wts, zs, extra_info,
+                saved_inv_model,
+                src_prompt,
+                tar_prompt,
+                steps,
+                cfg_scale_src,
+                cfg_scale_tar,
+                t_start,
+                randomize_seed,
+                save_compute,
+                ],
+        outputs=[output_audio, wts, zs, extra_info, saved_inv_model, do_inversion]
+    ).success(
+        fn=post_match_do_inversion,
+        inputs=[do_inversion_user, do_inversion],
+        outputs=[do_inversion_user, do_inversion]
+    )

     # If sources changed we have to rerun inversion
+    gr.on(
+        triggers=[input_audio.change, src_prompt.change, model_id.change, cfg_scale_src.change,
+                  steps.change, save_compute.change],
+        fn=reset_do_inversion,
+        inputs=[do_inversion_user, do_inversion],
+        outputs=[do_inversion_user, do_inversion]
+    )

     gr.Examples(
         label="Examples",
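The new app.py reserves ZeroGPU time dynamically: instead of a fixed number of seconds, `@spaces.GPU` is given the `get_duration` callable, which receives the same arguments as `edit` and returns an estimate based on the model, the number of diffusion steps, and the clip length. A minimal sketch of that pattern follows; it assumes the `spaces` package's callable-duration support used above, and `estimate_seconds`, `run_edit`, and the 0.15 s per-step constant are illustrative values, not the Space's exact figures.

import spaces


def estimate_seconds(steps: int = 50, needs_inversion: bool = True) -> float:
    # Rough cost model: inversion plus editing passes, times a per-step cost,
    # with a small safety margin and a 10-second floor.
    forwards = steps * (3 if needs_inversion else 1)
    return max(10.0, forwards * 0.15 + 5.0)


@spaces.GPU(duration=estimate_seconds)
def run_edit(steps: int = 50, needs_inversion: bool = True):
    ...  # GPU work runs inside the reserved window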
inversion_utils.py
CHANGED
|
@@ -1,341 +1,135 @@
|
|
| 1 |
import torch
|
| 2 |
from tqdm import tqdm
|
| 3 |
-
|
| 4 |
-
from typing import List, Optional, Dict, Union
|
| 5 |
from models import PipelineWrapper
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def mu_tilde(model, xt, x0, timestep):
|
| 9 |
-
"mu_tilde(x_t, x_0) DDPM paper eq. 7"
|
| 10 |
-
prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
|
| 11 |
-
alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
|
| 12 |
-
else model.scheduler.final_alpha_cumprod
|
| 13 |
-
alpha_t = model.scheduler.alphas[timestep]
|
| 14 |
-
beta_t = 1 - alpha_t
|
| 15 |
-
alpha_bar = model.scheduler.alphas_cumprod[timestep]
|
| 16 |
-
return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1-alpha_bar)) * x0 + \
|
| 17 |
-
((alpha_t**0.5 * (1-alpha_prod_t_prev)) / (1 - alpha_bar)) * xt
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def sample_xts_from_x0(model, x0, num_inference_steps=50, x_prev_mode=False):
|
| 21 |
-
"""
|
| 22 |
-
Samples from P(x_1:T|x_0)
|
| 23 |
-
"""
|
| 24 |
-
# torch.manual_seed(43256465436)
|
| 25 |
-
alpha_bar = model.model.scheduler.alphas_cumprod
|
| 26 |
-
sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5
|
| 27 |
-
alphas = model.model.scheduler.alphas
|
| 28 |
-
# betas = 1 - alphas
|
| 29 |
-
variance_noise_shape = (
|
| 30 |
-
num_inference_steps + 1,
|
| 31 |
-
model.model.unet.config.in_channels,
|
| 32 |
-
# model.unet.sample_size,
|
| 33 |
-
# model.unet.sample_size)
|
| 34 |
-
x0.shape[-2],
|
| 35 |
-
x0.shape[-1])
|
| 36 |
-
|
| 37 |
-
timesteps = model.model.scheduler.timesteps.to(model.device)
|
| 38 |
-
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
| 39 |
-
xts = torch.zeros(variance_noise_shape).to(x0.device)
|
| 40 |
-
xts[0] = x0
|
| 41 |
-
x_prev = x0
|
| 42 |
-
for t in reversed(timesteps):
|
| 43 |
-
# idx = t_to_idx[int(t)]
|
| 44 |
-
idx = num_inference_steps-t_to_idx[int(t)]
|
| 45 |
-
if x_prev_mode:
|
| 46 |
-
xts[idx] = x_prev * (alphas[t] ** 0.5) + torch.randn_like(x0) * ((1-alphas[t]) ** 0.5)
|
| 47 |
-
x_prev = xts[idx].clone()
|
| 48 |
-
else:
|
| 49 |
-
xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
|
| 50 |
-
# xts = torch.cat([xts, x0 ],dim = 0)
|
| 51 |
-
|
| 52 |
-
return xts
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def forward_step(model, model_output, timestep, sample):
|
| 56 |
-
next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
|
| 57 |
-
timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)
|
| 58 |
-
|
| 59 |
-
# 2. compute alphas, betas
|
| 60 |
-
alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
|
| 61 |
-
# alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 \
|
| 62 |
-
# else self.scheduler.final_alpha_cumprod
|
| 63 |
-
|
| 64 |
-
beta_prod_t = 1 - alpha_prod_t
|
| 65 |
-
|
| 66 |
-
# 3. compute predicted original sample from predicted noise also called
|
| 67 |
-
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
| 68 |
-
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
| 69 |
-
|
| 70 |
-
# 5. TODO: simple noising implementatiom
|
| 71 |
-
next_sample = model.scheduler.add_noise(pred_original_sample, model_output, torch.LongTensor([next_timestep]))
|
| 72 |
-
return next_sample
|
| 73 |
|
| 74 |
|
| 75 |
def inversion_forward_process(model: PipelineWrapper,
|
| 76 |
x0: torch.Tensor,
|
| 77 |
etas: Optional[float] = None,
|
| 78 |
-
prog_bar: bool = False,
|
| 79 |
prompts: List[str] = [""],
|
| 80 |
cfg_scales: List[float] = [3.5],
|
| 81 |
num_inference_steps: int = 50,
|
| 82 |
-
eps: Optional[float] = None,
|
| 83 |
-
cutoff_points: Optional[List[float]] = None,
|
| 84 |
numerical_fix: bool = False,
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
raise NotImplementedError("How do you split cfg_scales for hspace? TODO")
|
| 90 |
-
|
| 91 |
if len(prompts) > 1 or prompts[0] != "":
|
| 92 |
text_embeddings_hidden_states, text_embeddings_class_labels, \
|
| 93 |
text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
|
| 94 |
-
# text_embeddings = encode_text(model, prompt)
|
| 95 |
-
|
| 96 |
-
# # classifier free guidance
|
| 97 |
-
batch_size = len(prompts)
|
| 98 |
-
cfg_scales_tensor = torch.ones((batch_size, *x0.shape[1:]), device=model.device, dtype=x0.dtype)
|
| 99 |
-
|
| 100 |
-
# if len(prompts) > 1:
|
| 101 |
-
# if cutoff_points is None:
|
| 102 |
-
# cutoff_points = [i * 1 / batch_size for i in range(1, batch_size)]
|
| 103 |
-
# if len(cfg_scales) == 1:
|
| 104 |
-
# cfg_scales *= batch_size
|
| 105 |
-
# elif len(cfg_scales) < batch_size:
|
| 106 |
-
# raise ValueError("Not enough target CFG scales")
|
| 107 |
-
|
| 108 |
-
# cutoff_points = [int(x * cfg_scales_tensor.shape[2]) for x in cutoff_points]
|
| 109 |
-
# cutoff_points = [0, *cutoff_points, cfg_scales_tensor.shape[2]]
|
| 110 |
|
| 111 |
-
#
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
# else:
|
| 119 |
-
cfg_scales_tensor *= cfg_scales[0]
|
| 120 |
|
| 121 |
-
uncond_embedding_hidden_states, uncond_embedding_class_lables, uncond_boolean_prompt_mask = model.encode_text([""])
|
| 122 |
-
# uncond_embedding = encode_text(model, "")
|
| 123 |
timesteps = model.model.scheduler.timesteps.to(model.device)
|
| 124 |
-
variance_noise_shape = (
|
| 125 |
-
num_inference_steps,
|
| 126 |
-
model.model.unet.config.in_channels,
|
| 127 |
-
# model.unet.sample_size,
|
| 128 |
-
# model.unet.sample_size)
|
| 129 |
-
x0.shape[-2],
|
| 130 |
-
x0.shape[-1])
|
| 131 |
|
| 132 |
-
if
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
hspaces = []
|
| 143 |
-
skipconns = []
|
| 144 |
-
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
| 145 |
xt = x0
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
-
for t in op:
|
| 150 |
-
# idx = t_to_idx[int(t)]
|
| 151 |
-
idx = num_inference_steps - t_to_idx[int(t)] - 1
|
| 152 |
# 1. predict noise residual
|
| 153 |
-
|
| 154 |
-
|
| 155 |
|
| 156 |
with torch.no_grad():
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
if len(prompts) > 1 or prompts[0] != "":
|
| 171 |
# # classifier free guidance
|
| 172 |
-
noise_pred = out
|
| 173 |
-
(cfg_scales_tensor * (cond_out.sample - out.sample.expand(batch_size, -1, -1, -1))
|
| 174 |
-
).sum(axis=0).unsqueeze(0)
|
| 175 |
-
if extract_h_space or extract_skipconns:
|
| 176 |
-
noise_h_space = out_hspace + cfg_scales[0] * (cond_out_hspace - out_hspace)
|
| 177 |
-
if extract_skipconns:
|
| 178 |
-
noise_skipconns = {k: [out_skipconns[k][j] + cfg_scales[0] *
|
| 179 |
-
(cond_out_skipconns[k][j] - out_skipconns[k][j])
|
| 180 |
-
for j in range(len(out_skipconns[k]))]
|
| 181 |
-
for k in out_skipconns}
|
| 182 |
-
else:
|
| 183 |
-
noise_pred = out.sample
|
| 184 |
-
if extract_h_space or extract_skipconns:
|
| 185 |
-
noise_h_space = out_hspace
|
| 186 |
-
if extract_skipconns:
|
| 187 |
-
noise_skipconns = out_skipconns
|
| 188 |
-
if extract_h_space or extract_skipconns:
|
| 189 |
-
hspaces.append(noise_h_space)
|
| 190 |
-
if extract_skipconns:
|
| 191 |
-
skipconns.append(noise_skipconns)
|
| 192 |
-
|
| 193 |
-
if eta_is_zero:
|
| 194 |
-
# 2. compute more noisy image and set x_t -> x_t+1
|
| 195 |
-
xt = forward_step(model.model, noise_pred, t, xt)
|
| 196 |
else:
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
alpha_prod_t_prev = model.get_alpha_prod_t_prev(prev_timestep)
|
| 210 |
-
variance = model.get_variance(t, prev_timestep)
|
| 211 |
-
|
| 212 |
-
if model.model.scheduler.config.prediction_type == 'epsilon':
|
| 213 |
-
radom_noise_pred = noise_pred
|
| 214 |
-
elif model.model.scheduler.config.prediction_type == 'v_prediction':
|
| 215 |
-
radom_noise_pred = (alpha_bar[t] ** 0.5) * noise_pred + ((1 - alpha_bar[t]) ** 0.5) * xt
|
| 216 |
-
|
| 217 |
-
pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * radom_noise_pred
|
| 218 |
-
|
| 219 |
-
mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
| 220 |
-
|
| 221 |
-
z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
|
| 222 |
-
|
| 223 |
-
zs[idx] = z
|
| 224 |
-
|
| 225 |
-
# correction to avoid error accumulation
|
| 226 |
-
if numerical_fix:
|
| 227 |
-
xtm1 = mu_xt + (etas[idx] * variance ** 0.5)*z
|
| 228 |
-
xts[idx] = xtm1
|
| 229 |
|
| 230 |
if zs is not None:
|
| 231 |
# zs[-1] = torch.zeros_like(zs[-1])
|
| 232 |
zs[0] = torch.zeros_like(zs[0])
|
| 233 |
# zs_cycle[0] = torch.zeros_like(zs[0])
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
return xt, zs, xts, hspaces
|
| 238 |
-
|
| 239 |
-
if extract_skipconns:
|
| 240 |
-
hspaces = torch.concat(hspaces, axis=0)
|
| 241 |
-
return xt, zs, xts, hspaces, skipconns
|
| 242 |
-
|
| 243 |
-
return xt, zs, xts
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None):
|
| 247 |
-
# 1. get previous step value (=t-1)
|
| 248 |
-
prev_timestep = timestep - model.model.scheduler.config.num_train_timesteps // \
|
| 249 |
-
model.model.scheduler.num_inference_steps
|
| 250 |
-
# 2. compute alphas, betas
|
| 251 |
-
alpha_prod_t = model.model.scheduler.alphas_cumprod[timestep]
|
| 252 |
-
alpha_prod_t_prev = model.get_alpha_prod_t_prev(prev_timestep)
|
| 253 |
-
beta_prod_t = 1 - alpha_prod_t
|
| 254 |
-
# 3. compute predicted original sample from predicted noise also called
|
| 255 |
-
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
| 256 |
-
if model.model.scheduler.config.prediction_type == 'epsilon':
|
| 257 |
-
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
| 258 |
-
elif model.model.scheduler.config.prediction_type == 'v_prediction':
|
| 259 |
-
pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
|
| 260 |
-
|
| 261 |
-
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
| 262 |
-
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
| 263 |
-
# variance = self.scheduler._get_variance(timestep, prev_timestep)
|
| 264 |
-
variance = model.get_variance(timestep, prev_timestep)
|
| 265 |
-
# std_dev_t = eta * variance ** (0.5)
|
| 266 |
-
# Take care of asymetric reverse process (asyrp)
|
| 267 |
-
if model.model.scheduler.config.prediction_type == 'epsilon':
|
| 268 |
-
model_output_direction = model_output
|
| 269 |
-
elif model.model.scheduler.config.prediction_type == 'v_prediction':
|
| 270 |
-
model_output_direction = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
| 271 |
-
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
| 272 |
-
# pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
|
| 273 |
-
pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
|
| 274 |
-
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
| 275 |
-
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
| 276 |
-
# 8. Add noice if eta > 0
|
| 277 |
-
if eta > 0:
|
| 278 |
-
if variance_noise is None:
|
| 279 |
-
variance_noise = torch.randn(model_output.shape, device=model.device)
|
| 280 |
-
sigma_z = eta * variance ** (0.5) * variance_noise
|
| 281 |
-
prev_sample = prev_sample + sigma_z
|
| 282 |
-
|
| 283 |
-
return prev_sample
|
| 284 |
|
| 285 |
|
| 286 |
def inversion_reverse_process(model: PipelineWrapper,
|
| 287 |
xT: torch.Tensor,
|
| 288 |
-
|
| 289 |
-
fix_alpha: float = 0.1,
|
| 290 |
etas: float = 0,
|
| 291 |
prompts: List[str] = [""],
|
| 292 |
neg_prompts: List[str] = [""],
|
| 293 |
cfg_scales: Optional[List[float]] = None,
|
| 294 |
-
prog_bar: bool = False,
|
| 295 |
zs: Optional[List[torch.Tensor]] = None,
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
zero_out_resconns: Optional[Union[int, List]] = None,
|
| 302 |
-
asyrp: bool = False,
|
| 303 |
-
extract_h_space: bool = False,
|
| 304 |
-
extract_skipconns: bool = False):
|
| 305 |
-
|
| 306 |
-
batch_size = len(prompts)
|
| 307 |
|
| 308 |
text_embeddings_hidden_states, text_embeddings_class_labels, \
|
| 309 |
text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
|
| 310 |
-
|
| 311 |
-
uncond_boolean_prompt_mask = model.encode_text(neg_prompts
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
cfg_scales_tensor = torch.ones((batch_size, *xT.shape[1:]), device=model.device, dtype=xT.dtype)
|
| 317 |
-
|
| 318 |
-
# if batch_size > 1:
|
| 319 |
-
# if cutoff_points is None:
|
| 320 |
-
# cutoff_points = [i * 1 / batch_size for i in range(1, batch_size)]
|
| 321 |
-
# if len(cfg_scales) == 1:
|
| 322 |
-
# cfg_scales *= batch_size
|
| 323 |
-
# elif len(cfg_scales) < batch_size:
|
| 324 |
-
# raise ValueError("Not enough target CFG scales")
|
| 325 |
-
|
| 326 |
-
# cutoff_points = [int(x * cfg_scales_tensor.shape[2]) for x in cutoff_points]
|
| 327 |
-
# cutoff_points = [0, *cutoff_points, cfg_scales_tensor.shape[2]]
|
| 328 |
|
| 329 |
-
|
| 330 |
-
# cfg_scales_tensor[i, :, end:] = 0
|
| 331 |
-
# cfg_scales_tensor[i, :, :start] = 0
|
| 332 |
-
# masks[i, :, end:] = 0
|
| 333 |
-
# masks[i, :, :start] = 0
|
| 334 |
-
# cfg_scales_tensor[i] *= cfg_scales[i]
|
| 335 |
-
# cfg_scales_tensor = T.functional.gaussian_blur(cfg_scales_tensor, kernel_size=15, sigma=1)
|
| 336 |
-
# masks = T.functional.gaussian_blur(masks, kernel_size=15, sigma=1)
|
| 337 |
-
# else:
|
| 338 |
-
cfg_scales_tensor *= cfg_scales[0]
|
| 339 |
|
| 340 |
if etas is None:
|
| 341 |
etas = 0
|
|
@@ -344,107 +138,71 @@ def inversion_reverse_process(model: PipelineWrapper,
|
|
| 344 |
assert len(etas) == model.model.scheduler.num_inference_steps
|
| 345 |
timesteps = model.model.scheduler.timesteps.to(model.device)
|
| 346 |
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
-
t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
|
| 352 |
-
hspaces = []
|
| 353 |
-
skipconns = []
|
| 354 |
-
|
| 355 |
-
for it, t in enumerate(op):
|
| 356 |
-
# idx = t_to_idx[int(t)]
|
| 357 |
-
idx = model.model.scheduler.num_inference_steps - t_to_idx[int(t)] - \
|
| 358 |
-
(model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
|
| 359 |
# # Unconditional embedding
|
| 360 |
with torch.no_grad():
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
timestep=t,
|
| 385 |
encoder_hidden_states=text_embeddings_hidden_states,
|
| 386 |
class_labels=text_embeddings_class_labels,
|
| 387 |
encoder_attention_mask=text_embeddings_boolean_prompt_mask,
|
| 388 |
-
|
| 389 |
-
(cfg_scales[0] / (cfg_scales[0] + 1)) *
|
| 390 |
-
(hspace_add[-zs.shape[0]:][it] if hspace_add.shape[0] > 1
|
| 391 |
-
else hspace_add)),
|
| 392 |
-
replace_h_space=(None if hspace_replace is None else
|
| 393 |
-
(hspace_replace[-zs.shape[0]:][it].unsqueeze(0) if hspace_replace.shape[0] > 1
|
| 394 |
-
else hspace_replace)),
|
| 395 |
-
zero_out_resconns=zero_out_resconns,
|
| 396 |
-
replace_skip_conns=(None if skipconns_replace is None else
|
| 397 |
-
(skipconns_replace[-zs.shape[0]:][it] if len(skipconns_replace) > 1
|
| 398 |
-
else skipconns_replace))
|
| 399 |
-
) # encoder_hidden_states = text_embeddings)
|
| 400 |
|
| 401 |
z = zs[idx] if zs is not None else None
|
| 402 |
-
# print(f'idx: {idx}')
|
| 403 |
-
# print(f't: {t}')
|
| 404 |
z = z.unsqueeze(0)
|
| 405 |
-
#
|
| 406 |
-
|
| 407 |
-
# # classifier free guidance
|
| 408 |
-
# noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
|
| 409 |
-
noise_pred = uncond_out.sample + \
|
| 410 |
-
(cfg_scales_tensor * (cond_out.sample - uncond_out.sample.expand(batch_size, -1, -1, -1))
|
| 411 |
-
).sum(axis=0).unsqueeze(0)
|
| 412 |
-
if extract_h_space or extract_skipconns:
|
| 413 |
-
noise_h_space = out_hspace + cfg_scales[0] * (cond_out_hspace - out_hspace)
|
| 414 |
-
if extract_skipconns:
|
| 415 |
-
noise_skipconns = {k: [out_skipconns[k][j] + cfg_scales[0] *
|
| 416 |
-
(cond_out_skipconns[k][j] - out_skipconns[k][j])
|
| 417 |
-
for j in range(len(out_skipconns[k]))]
|
| 418 |
-
for k in out_skipconns}
|
| 419 |
-
else:
|
| 420 |
-
noise_pred = uncond_out.sample
|
| 421 |
-
if extract_h_space or extract_skipconns:
|
| 422 |
-
noise_h_space = out_hspace
|
| 423 |
-
if extract_skipconns:
|
| 424 |
-
noise_skipconns = out_skipconns
|
| 425 |
-
|
| 426 |
-
if extract_h_space or extract_skipconns:
|
| 427 |
-
hspaces.append(noise_h_space)
|
| 428 |
-
if extract_skipconns:
|
| 429 |
-
skipconns.append(noise_skipconns)
|
| 430 |
|
| 431 |
# 2. compute less noisy image and set x_t -> x_t-1
|
| 432 |
-
xt =
|
| 433 |
-
|
| 434 |
-
# xt = controller.step_callback(xt)
|
| 435 |
-
|
| 436 |
-
# "fix" xt
|
| 437 |
-
apply_fix = ((skips.max() - skips) > it)
|
| 438 |
-
if apply_fix.any():
|
| 439 |
-
apply_fix = (apply_fix * fix_alpha).unsqueeze(1).unsqueeze(2).unsqueeze(3).to(xT.device)
|
| 440 |
-
xt = (masks * (xt.expand(batch_size, -1, -1, -1) * (1 - apply_fix) +
|
| 441 |
-
apply_fix * xT[skips.max() - it - 1].expand(batch_size, -1, -1, -1))
|
| 442 |
-
).sum(axis=0).unsqueeze(0)
|
| 443 |
-
|
| 444 |
-
if extract_h_space:
|
| 445 |
-
return xt, zs, torch.concat(hspaces, axis=0)
|
| 446 |
-
|
| 447 |
-
if extract_skipconns:
|
| 448 |
-
return xt, zs, torch.concat(hspaces, axis=0), skipconns
|
| 449 |
|
|
|
|
| 450 |
return xt, zs
|
|
|
|
| 1 |
import torch
|
| 2 |
from tqdm import tqdm
|
| 3 |
+
from typing import List, Optional, Tuple
|
|
|
|
| 4 |
from models import PipelineWrapper
|
| 5 |
+
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
def inversion_forward_process(model: PipelineWrapper,
|
| 9 |
x0: torch.Tensor,
|
| 10 |
etas: Optional[float] = None,
|
|
|
|
| 11 |
prompts: List[str] = [""],
|
| 12 |
cfg_scales: List[float] = [3.5],
|
| 13 |
num_inference_steps: int = 50,
|
|
|
|
|
|
|
| 14 |
numerical_fix: bool = False,
|
| 15 |
+
duration: Optional[float] = None,
|
| 16 |
+
first_order: bool = False,
|
| 17 |
+
save_compute: bool = True,
|
| 18 |
+
progress=gr.Progress()) -> Tuple:
|
|
|
|
|
|
|
| 19 |
if len(prompts) > 1 or prompts[0] != "":
|
| 20 |
text_embeddings_hidden_states, text_embeddings_class_labels, \
|
| 21 |
text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
# In the forward negative prompts are not supported currently (TODO)
|
| 24 |
+
uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text(
|
| 25 |
+
[""], negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1]
|
| 26 |
+
if text_embeddings_class_labels is not None else None)
|
| 27 |
+
else:
|
| 28 |
+
uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text(
|
| 29 |
+
[""], negative=True, save_compute=False)
|
| 30 |
|
| 31 |
timesteps = model.model.scheduler.timesteps.to(model.device)
|
| 32 |
+
variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)
|
|
|
| 33 |
|
| 34 |
+
if type(etas) in [int, float]:
|
| 35 |
+
etas = [etas]*model.model.scheduler.num_inference_steps
|
| 36 |
+
xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
|
| 37 |
+
zs = torch.zeros(size=variance_noise_shape, device=model.device)
|
| 38 |
+
extra_info = [None] * len(zs)
|
| 39 |
+
|
| 40 |
+
if timesteps[0].dtype == torch.int64:
|
| 41 |
+
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
| 42 |
+
elif timesteps[0].dtype == torch.float32:
|
| 43 |
+
t_to_idx = {float(v): k for k, v in enumerate(timesteps)}
|
| 44 |
xt = x0
|
| 45 |
+
op = tqdm(timesteps, desc="Inverting")
|
| 46 |
+
model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration,
|
| 47 |
+
save_compute=save_compute and prompts[0] != "")
|
| 48 |
+
app_op = progress.tqdm(timesteps, desc="Inverting")
|
| 49 |
+
for t, _ in zip(op, app_op):
|
| 50 |
+
idx = num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - 1
|
| 51 |
|
| 52 |
# 1. predict noise residual
|
| 53 |
+
xt = xts[idx+1][None]
|
| 54 |
+
xt_inp = model.model.scheduler.scale_model_input(xt, t)
|
| 55 |
|
| 56 |
with torch.no_grad():
|
| 57 |
+
if save_compute and prompts[0] != "":
|
| 58 |
+
comb_out, _, _ = model.unet_forward(
|
| 59 |
+
xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1),
|
| 60 |
+
timestep=t,
|
| 61 |
+
encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states
|
| 62 |
+
], dim=0)
|
| 63 |
+
if uncond_embeddings_hidden_states is not None else None,
|
| 64 |
+
class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0)
|
| 65 |
+
if uncond_embeddings_class_lables is not None else None,
|
| 66 |
+
encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask
|
| 67 |
+
], dim=0)
|
| 68 |
+
if uncond_boolean_prompt_mask is not None else None,
|
| 69 |
+
)
|
| 70 |
+
out, cond_out = comb_out.sample.chunk(2, dim=0)
|
| 71 |
+
else:
|
| 72 |
+
out = model.unet_forward(xt_inp, timestep=t,
|
| 73 |
+
encoder_hidden_states=uncond_embeddings_hidden_states,
|
| 74 |
+
class_labels=uncond_embeddings_class_lables,
|
| 75 |
+
encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
|
| 76 |
+
if len(prompts) > 1 or prompts[0] != "":
|
| 77 |
+
cond_out = model.unet_forward(
|
| 78 |
+
xt_inp,
|
| 79 |
+
timestep=t,
|
| 80 |
+
encoder_hidden_states=text_embeddings_hidden_states,
|
| 81 |
+
class_labels=text_embeddings_class_labels,
|
| 82 |
+
encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample
|
| 83 |
|
| 84 |
if len(prompts) > 1 or prompts[0] != "":
|
| 85 |
# # classifier free guidance
|
| 86 |
+
noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
|
|
|
| 87 |
else:
|
| 88 |
+
noise_pred = out
|
| 89 |
+
|
| 90 |
+
# xtm1 = xts[idx+1][None]
|
| 91 |
+
xtm1 = xts[idx][None]
|
| 92 |
+
z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t,
|
| 93 |
+
eta=etas[idx], numerical_fix=numerical_fix,
|
| 94 |
+
first_order=first_order)
|
| 95 |
+
zs[idx] = z
|
| 96 |
+
# print(f"Fix Xt-1 distance - NORM:{torch.norm(xts[idx] - xtm1):.4g}, MSE:{((xts[idx] - xtm1)**2).mean():.4g}")
|
| 97 |
+
xts[idx] = xtm1
|
| 98 |
+
extra_info[idx] = extra
|
|
|
| 99 |
|
| 100 |
if zs is not None:
|
| 101 |
# zs[-1] = torch.zeros_like(zs[-1])
|
| 102 |
zs[0] = torch.zeros_like(zs[0])
|
| 103 |
# zs_cycle[0] = torch.zeros_like(zs[0])
|
| 104 |
|
| 105 |
+
del app_op.iterables[0]
|
| 106 |
+
return xt, zs, xts, extra_info
|
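A minimal calling sketch for inversion_forward_process as defined above; only the function signature comes from this diff, while the loader, device, prompt, and step count are placeholder assumptions:

# Illustrative only, not part of this commit.
device = "cuda"
ldm = load_model("cvssp/audioldm2", device, double_precision=False)   # hypothetical arguments
ldm.model.scheduler.set_timesteps(50, device=device)                  # timesteps must be set beforehand
xT, zs, xts, extra = inversion_forward_process(
    ldm, x0,                        # x0: latent of the source audio (e.g. from ldm.vae_encode)
    etas=1.0,                       # stochastic, edit-friendly inversion
    prompts=["a piano melody"],
    cfg_scales=[3.5],
    num_inference_steps=50,
    numerical_fix=True,
)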
|
|
| 107 |
|
| 108 |
|
| 109 |
def inversion_reverse_process(model: PipelineWrapper,
|
| 110 |
xT: torch.Tensor,
|
| 111 |
+
tstart: torch.Tensor,
|
|
|
|
| 112 |
etas: float = 0,
|
| 113 |
prompts: List[str] = [""],
|
| 114 |
neg_prompts: List[str] = [""],
|
| 115 |
cfg_scales: Optional[List[float]] = None,
|
|
|
|
| 116 |
zs: Optional[List[torch.Tensor]] = None,
|
| 117 |
+
duration: Optional[float] = None,
|
| 118 |
+
first_order: bool = False,
|
| 119 |
+
extra_info: Optional[List] = None,
|
| 120 |
+
save_compute: bool = True,
|
| 121 |
+
progress=gr.Progress()) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
| 122 |
|
| 123 |
text_embeddings_hidden_states, text_embeddings_class_labels, \
|
| 124 |
text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
|
| 125 |
+
uncond_embeddings_hidden_states, uncond_embeddings_class_lables, \
|
| 126 |
+
uncond_boolean_prompt_mask = model.encode_text(neg_prompts,
|
| 127 |
+
negative=True,
|
| 128 |
+
save_compute=save_compute,
|
| 129 |
+
cond_length=text_embeddings_class_labels.shape[1]
|
| 130 |
+
if text_embeddings_class_labels is not None else None)
|
|
|
| 131 |
|
| 132 |
+
xt = xT[tstart.max()].unsqueeze(0)
|
|
|
| 133 |
|
| 134 |
if etas is None:
|
| 135 |
etas = 0
|
|
|
|
| 138 |
assert len(etas) == model.model.scheduler.num_inference_steps
|
| 139 |
timesteps = model.model.scheduler.timesteps.to(model.device)
|
| 140 |
|
| 141 |
+
op = tqdm(timesteps[-zs.shape[0]:], desc="Editing")
|
| 142 |
+
if timesteps[0].dtype == torch.int64:
|
| 143 |
+
t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
|
| 144 |
+
elif timesteps[0].dtype == torch.float32:
|
| 145 |
+
t_to_idx = {float(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
|
| 146 |
+
model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=timesteps[-zs.shape[0]],
|
| 147 |
+
audio_end_in_s=duration, save_compute=save_compute)
|
| 148 |
+
app_op = progress.tqdm(timesteps[-zs.shape[0]:], desc="Editing")
|
| 149 |
+
for it, (t, _) in enumerate(zip(op, app_op)):
|
| 150 |
+
idx = model.model.scheduler.num_inference_steps - t_to_idx[
|
| 151 |
+
int(t) if timesteps[0].dtype == torch.int64 else float(t)] - \
|
| 152 |
+
(model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
|
| 153 |
+
|
| 154 |
+
xt_inp = model.model.scheduler.scale_model_input(xt, t)
|
| 155 |
|
|
|
| 156 |
# # Unconditional embedding
|
| 157 |
with torch.no_grad():
|
| 158 |
+
# print(f'xt_inp.shape: {xt_inp.shape}')
|
| 159 |
+
# print(f't.shape: {t.shape}')
|
| 160 |
+
# print(f'uncond_embeddings_hidden_states.shape: {uncond_embeddings_hidden_states.shape}')
|
| 161 |
+
# print(f'uncond_embeddings_class_lables.shape: {uncond_embeddings_class_lables.shape}')
|
| 162 |
+
# print(f'uncond_boolean_prompt_mask.shape: {uncond_boolean_prompt_mask.shape}')
|
| 163 |
+
# print(f'text_embeddings_hidden_states.shape: {text_embeddings_hidden_states.shape}')
|
| 164 |
+
# print(f'text_embeddings_class_labels.shape: {text_embeddings_class_labels.shape}')
|
| 165 |
+
# print(f'text_embeddings_boolean_prompt_mask.shape: {text_embeddings_boolean_prompt_mask.shape}')
|
| 166 |
+
|
| 167 |
+
if save_compute:
|
| 168 |
+
comb_out, _, _ = model.unet_forward(
|
| 169 |
+
xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1),
|
| 170 |
+
timestep=t,
|
| 171 |
+
encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states
|
| 172 |
+
], dim=0)
|
| 173 |
+
if uncond_embeddings_hidden_states is not None else None,
|
| 174 |
+
class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0)
|
| 175 |
+
if uncond_embeddings_class_lables is not None else None,
|
| 176 |
+
encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask
|
| 177 |
+
], dim=0)
|
| 178 |
+
if uncond_boolean_prompt_mask is not None else None,
|
| 179 |
+
)
|
| 180 |
+
uncond_out, cond_out = comb_out.sample.chunk(2, dim=0)
|
| 181 |
+
else:
|
| 182 |
+
uncond_out = model.unet_forward(
|
| 183 |
+
xt_inp, timestep=t,
|
| 184 |
+
encoder_hidden_states=uncond_embeddings_hidden_states,
|
| 185 |
+
class_labels=uncond_embeddings_class_lables,
|
| 186 |
+
encoder_attention_mask=uncond_boolean_prompt_mask,
|
| 187 |
+
)[0].sample
|
| 188 |
+
|
| 189 |
+
# Conditional embedding
|
| 190 |
+
cond_out = model.unet_forward(
|
| 191 |
+
xt_inp,
|
| 192 |
timestep=t,
|
| 193 |
encoder_hidden_states=text_embeddings_hidden_states,
|
| 194 |
class_labels=text_embeddings_class_labels,
|
| 195 |
encoder_attention_mask=text_embeddings_boolean_prompt_mask,
|
| 196 |
+
)[0].sample
|
|
|
| 197 |
|
| 198 |
z = zs[idx] if zs is not None else None
|
| 199 |
z = z.unsqueeze(0)
|
| 200 |
+
# classifier free guidance
|
| 201 |
+
noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)
|
|
|
| 202 |
|
| 203 |
# 2. compute less noisy image and set x_t -> x_t-1
|
| 204 |
+
xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z,
|
| 205 |
+
eta=etas[idx], first_order=first_order)
|
|
|
| 206 |
|
| 207 |
+
del app_op.iterables[0]
|
| 208 |
return xt, zs
|
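Taken together, the two functions above form the edit loop driven by app.py; a rough end-to-end sketch (variable names, cfg values, and the zs slice are assumptions, not code from this commit):

# Illustrative only.
t_start_idx = 40                                   # how many of the 50 steps to re-run
xT, zs, xts, extra = inversion_forward_process(ldm, x0, etas=1.0, prompts=[source_prompt],
                                               cfg_scales=[3.5], num_inference_steps=50,
                                               numerical_fix=True)
edited_latent, _ = inversion_reverse_process(ldm, xT=xts, tstart=torch.tensor(t_start_idx),
                                             etas=1.0, prompts=[target_prompt], neg_prompts=[""],
                                             cfg_scales=[12.0], zs=zs[:t_start_idx],
                                             extra_info=extra)
# For the AudioLDM2 wrapper the decoded latent is a mel spectrogram that still
# needs vocoding; the StableAudio wrapper returns a waveform directly.
waveform = ldm.decode_to_mel(ldm.vae_decode(edited_latent))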
models.py
CHANGED
|
@@ -1,46 +1,160 @@
|
|
| 1 |
import torch
|
| 2 |
-
from diffusers import DDIMScheduler
|
| 3 |
-
from diffusers import
|
| 4 |
-
from
|
|
|
|
| 5 |
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
|
|
|
|
| 6 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
class PipelineWrapper(torch.nn.Module):
|
| 10 |
-
def __init__(self, model_id
|
| 11 |
super().__init__(*args, **kwargs)
|
| 12 |
self.model_id = model_id
|
| 13 |
self.device = device
|
| 14 |
self.double_precision = double_precision
|
|
|
|
| 15 |
|
| 16 |
-
def get_sigma(self, timestep) -> float:
|
| 17 |
sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.model.scheduler.alphas_cumprod - 1)
|
| 18 |
return sqrt_recipm1_alphas_cumprod[timestep]
|
| 19 |
|
| 20 |
-
def load_scheduler(self):
|
| 21 |
pass
|
| 22 |
|
| 23 |
-
def get_fn_STFT(self):
|
| 24 |
pass
|
| 25 |
|
| 26 |
-
def
|
|
|
| 27 |
pass
|
| 28 |
|
| 29 |
-
def
|
| 30 |
pass
|
| 31 |
|
| 32 |
-
def
|
| 33 |
pass
|
| 34 |
|
| 35 |
-
def encode_text(self, prompts: List[str]
|
|
|
|
| 36 |
pass
|
| 37 |
|
| 38 |
-
def get_variance(self, timestep, prev_timestep):
|
| 39 |
pass
|
| 40 |
|
| 41 |
-
def get_alpha_prod_t_prev(self, prev_timestep):
|
| 42 |
pass
|
| 43 |
|
|
|
|
|
|
| 44 |
def unet_forward(self,
|
| 45 |
sample: torch.FloatTensor,
|
| 46 |
timestep: Union[torch.Tensor, float, int],
|
|
@@ -57,244 +171,27 @@ class PipelineWrapper(torch.nn.Module):
|
|
| 57 |
replace_skip_conns: Optional[Dict[int, torch.Tensor]] = None,
|
| 58 |
return_dict: bool = True,
|
| 59 |
zero_out_resconns: Optional[Union[int, List]] = None) -> Tuple:
|
| 60 |
-
|
| 61 |
-
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
| 62 |
-
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
| 63 |
-
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
| 64 |
-
# on the fly if necessary.
|
| 65 |
-
default_overall_up_factor = 2**self.model.unet.num_upsamplers
|
| 66 |
-
|
| 67 |
-
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
| 68 |
-
forward_upsample_size = False
|
| 69 |
-
upsample_size = None
|
| 70 |
-
|
| 71 |
-
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
| 72 |
-
# logger.info("Forward upsample size to force interpolation output size.")
|
| 73 |
-
forward_upsample_size = True
|
| 74 |
-
|
| 75 |
-
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
| 76 |
-
# expects mask of shape:
|
| 77 |
-
# [batch, key_tokens]
|
| 78 |
-
# adds singleton query_tokens dimension:
|
| 79 |
-
# [batch, 1, key_tokens]
|
| 80 |
-
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
| 81 |
-
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
| 82 |
-
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
| 83 |
-
if attention_mask is not None:
|
| 84 |
-
# assume that mask is expressed as:
|
| 85 |
-
# (1 = keep, 0 = discard)
|
| 86 |
-
# convert mask into a bias that can be added to attention scores:
|
| 87 |
-
# (keep = +0, discard = -10000.0)
|
| 88 |
-
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
| 89 |
-
attention_mask = attention_mask.unsqueeze(1)
|
| 90 |
-
|
| 91 |
-
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
| 92 |
-
if encoder_attention_mask is not None:
|
| 93 |
-
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
| 94 |
-
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
| 95 |
-
|
| 96 |
-
# 0. center input if necessary
|
| 97 |
-
if self.model.unet.config.center_input_sample:
|
| 98 |
-
sample = 2 * sample - 1.0
|
| 99 |
-
|
| 100 |
-
# 1. time
|
| 101 |
-
timesteps = timestep
|
| 102 |
-
if not torch.is_tensor(timesteps):
|
| 103 |
-
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
| 104 |
-
# This would be a good case for the `match` statement (Python 3.10+)
|
| 105 |
-
is_mps = sample.device.type == "mps"
|
| 106 |
-
if isinstance(timestep, float):
|
| 107 |
-
dtype = torch.float32 if is_mps else torch.float64
|
| 108 |
-
else:
|
| 109 |
-
dtype = torch.int32 if is_mps else torch.int64
|
| 110 |
-
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
| 111 |
-
elif len(timesteps.shape) == 0:
|
| 112 |
-
timesteps = timesteps[None].to(sample.device)
|
| 113 |
-
|
| 114 |
-
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 115 |
-
timesteps = timesteps.expand(sample.shape[0])
|
| 116 |
-
|
| 117 |
-
t_emb = self.model.unet.time_proj(timesteps)
|
| 118 |
-
|
| 119 |
-
# `Timesteps` does not contain any weights and will always return f32 tensors
|
| 120 |
-
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 121 |
-
# there might be better ways to encapsulate this.
|
| 122 |
-
t_emb = t_emb.to(dtype=sample.dtype)
|
| 123 |
-
|
| 124 |
-
emb = self.model.unet.time_embedding(t_emb, timestep_cond)
|
| 125 |
-
|
| 126 |
-
if self.model.unet.class_embedding is not None:
|
| 127 |
-
if class_labels is None:
|
| 128 |
-
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
| 129 |
-
|
| 130 |
-
if self.model.unet.config.class_embed_type == "timestep":
|
| 131 |
-
class_labels = self.model.unet.time_proj(class_labels)
|
| 132 |
-
|
| 133 |
-
# `Timesteps` does not contain any weights and will always return f32 tensors
|
| 134 |
-
# there might be better ways to encapsulate this.
|
| 135 |
-
class_labels = class_labels.to(dtype=sample.dtype)
|
| 136 |
-
|
| 137 |
-
class_emb = self.model.unet.class_embedding(class_labels).to(dtype=sample.dtype)
|
| 138 |
-
|
| 139 |
-
if self.model.unet.config.class_embeddings_concat:
|
| 140 |
-
emb = torch.cat([emb, class_emb], dim=-1)
|
| 141 |
-
else:
|
| 142 |
-
emb = emb + class_emb
|
| 143 |
-
|
| 144 |
-
if self.model.unet.config.addition_embed_type == "text":
|
| 145 |
-
aug_emb = self.model.unet.add_embedding(encoder_hidden_states)
|
| 146 |
-
emb = emb + aug_emb
|
| 147 |
-
elif self.model.unet.config.addition_embed_type == "text_image":
|
| 148 |
-
# Kadinsky 2.1 - style
|
| 149 |
-
if "image_embeds" not in added_cond_kwargs:
|
| 150 |
-
raise ValueError(
|
| 151 |
-
f"{self.model.unet.__class__} has the config param `addition_embed_type` set to 'text_image' "
|
| 152 |
-
f"which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
| 153 |
-
)
|
| 154 |
-
|
| 155 |
-
image_embs = added_cond_kwargs.get("image_embeds")
|
| 156 |
-
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
| 157 |
-
|
| 158 |
-
aug_emb = self.model.unet.add_embedding(text_embs, image_embs)
|
| 159 |
-
emb = emb + aug_emb
|
| 160 |
-
|
| 161 |
-
if self.model.unet.time_embed_act is not None:
|
| 162 |
-
emb = self.model.unet.time_embed_act(emb)
|
| 163 |
-
|
| 164 |
-
if self.model.unet.encoder_hid_proj is not None and self.model.unet.config.encoder_hid_dim_type == "text_proj":
|
| 165 |
-
encoder_hidden_states = self.model.unet.encoder_hid_proj(encoder_hidden_states)
|
| 166 |
-
elif self.model.unet.encoder_hid_proj is not None and \
|
| 167 |
-
self.model.unet.config.encoder_hid_dim_type == "text_image_proj":
|
| 168 |
-
# Kadinsky 2.1 - style
|
| 169 |
-
if "image_embeds" not in added_cond_kwargs:
|
| 170 |
-
raise ValueError(
|
| 171 |
-
f"{self.model.unet.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' "
|
| 172 |
-
f"which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
| 173 |
-
)
|
| 174 |
-
|
| 175 |
-
image_embeds = added_cond_kwargs.get("image_embeds")
|
| 176 |
-
encoder_hidden_states = self.model.unet.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
| 177 |
-
|
| 178 |
-
# 2. pre-process
|
| 179 |
-
sample = self.model.unet.conv_in(sample)
|
| 180 |
-
|
| 181 |
-
# 3. down
|
| 182 |
-
down_block_res_samples = (sample,)
|
| 183 |
-
for downsample_block in self.model.unet.down_blocks:
|
| 184 |
-
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
| 185 |
-
sample, res_samples = downsample_block(
|
| 186 |
-
hidden_states=sample,
|
| 187 |
-
temb=emb,
|
| 188 |
-
encoder_hidden_states=encoder_hidden_states,
|
| 189 |
-
attention_mask=attention_mask,
|
| 190 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
| 191 |
-
encoder_attention_mask=encoder_attention_mask,
|
| 192 |
-
)
|
| 193 |
-
else:
|
| 194 |
-
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
| 195 |
-
|
| 196 |
-
down_block_res_samples += res_samples
|
| 197 |
-
|
| 198 |
-
if down_block_additional_residuals is not None:
|
| 199 |
-
new_down_block_res_samples = ()
|
| 200 |
-
|
| 201 |
-
for down_block_res_sample, down_block_additional_residual in zip(
|
| 202 |
-
down_block_res_samples, down_block_additional_residuals
|
| 203 |
-
):
|
| 204 |
-
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
| 205 |
-
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
| 206 |
-
|
| 207 |
-
down_block_res_samples = new_down_block_res_samples
|
| 208 |
-
|
| 209 |
-
# 4. mid
|
| 210 |
-
if self.model.unet.mid_block is not None:
|
| 211 |
-
sample = self.model.unet.mid_block(
|
| 212 |
-
sample,
|
| 213 |
-
emb,
|
| 214 |
-
encoder_hidden_states=encoder_hidden_states,
|
| 215 |
-
attention_mask=attention_mask,
|
| 216 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
| 217 |
-
encoder_attention_mask=encoder_attention_mask,
|
| 218 |
-
)
|
| 219 |
-
|
| 220 |
-
# print(sample.shape)
|
| 221 |
-
|
| 222 |
-
if replace_h_space is None:
|
| 223 |
-
h_space = sample.clone()
|
| 224 |
-
else:
|
| 225 |
-
h_space = replace_h_space
|
| 226 |
-
sample = replace_h_space.clone()
|
| 227 |
-
|
| 228 |
-
if mid_block_additional_residual is not None:
|
| 229 |
-
sample = sample + mid_block_additional_residual
|
| 230 |
-
|
| 231 |
-
extracted_res_conns = {}
|
| 232 |
-
# 5. up
|
| 233 |
-
for i, upsample_block in enumerate(self.model.unet.up_blocks):
|
| 234 |
-
is_final_block = i == len(self.model.unet.up_blocks) - 1
|
| 235 |
-
|
| 236 |
-
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
| 237 |
-
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
| 238 |
-
if replace_skip_conns is not None and replace_skip_conns.get(i):
|
| 239 |
-
res_samples = replace_skip_conns.get(i)
|
| 240 |
-
|
| 241 |
-
if zero_out_resconns is not None:
|
| 242 |
-
if (type(zero_out_resconns) is int and i >= (zero_out_resconns - 1)) or \
|
| 243 |
-
type(zero_out_resconns) is list and i in zero_out_resconns:
|
| 244 |
-
res_samples = [torch.zeros_like(x) for x in res_samples]
|
| 245 |
-
# down_block_res_samples = [torch.zeros_like(x) for x in down_block_res_samples]
|
| 246 |
-
|
| 247 |
-
extracted_res_conns[i] = res_samples
|
| 248 |
-
|
| 249 |
-
# if we have not reached the final block and need to forward the
|
| 250 |
-
# upsample size, we do it here
|
| 251 |
-
if not is_final_block and forward_upsample_size:
|
| 252 |
-
upsample_size = down_block_res_samples[-1].shape[2:]
|
| 253 |
-
|
| 254 |
-
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
| 255 |
-
sample = upsample_block(
|
| 256 |
-
hidden_states=sample,
|
| 257 |
-
temb=emb,
|
| 258 |
-
res_hidden_states_tuple=res_samples,
|
| 259 |
-
encoder_hidden_states=encoder_hidden_states,
|
| 260 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
| 261 |
-
upsample_size=upsample_size,
|
| 262 |
-
attention_mask=attention_mask,
|
| 263 |
-
encoder_attention_mask=encoder_attention_mask,
|
| 264 |
-
)
|
| 265 |
-
else:
|
| 266 |
-
sample = upsample_block(
|
| 267 |
-
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
| 268 |
-
)
|
| 269 |
-
|
| 270 |
-
# 6. post-process
|
| 271 |
-
if self.model.unet.conv_norm_out:
|
| 272 |
-
sample = self.model.unet.conv_norm_out(sample)
|
| 273 |
-
sample = self.model.unet.conv_act(sample)
|
| 274 |
-
sample = self.model.unet.conv_out(sample)
|
| 275 |
-
|
| 276 |
-
if not return_dict:
|
| 277 |
-
return (sample,)
|
| 278 |
-
|
| 279 |
-
return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
|
| 280 |
|
| 281 |
|
| 282 |
class AudioLDM2Wrapper(PipelineWrapper):
|
| 283 |
def __init__(self, *args, **kwargs) -> None:
|
| 284 |
super().__init__(*args, **kwargs)
|
| 285 |
if self.double_precision:
|
| 286 |
-
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, torch_dtype=torch.float64
|
|
|
|
| 287 |
else:
|
| 288 |
try:
|
| 289 |
-
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=True
|
|
|
|
| 290 |
except FileNotFoundError:
|
| 291 |
-
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=False
|
|
|
|
| 292 |
|
| 293 |
-
def load_scheduler(self):
|
| 294 |
-
# self.model.scheduler = DDIMScheduler.from_config(self.model_id, subfolder="scheduler")
|
| 295 |
self.model.scheduler = DDIMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
|
| 296 |
|
| 297 |
-
def get_fn_STFT(self):
|
| 298 |
from audioldm.audio import TacotronSTFT
|
| 299 |
return TacotronSTFT(
|
| 300 |
filter_length=1024,
|
|
@@ -306,17 +203,17 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
| 306 |
mel_fmax=8000,
|
| 307 |
)
|
| 308 |
|
| 309 |
-
def vae_encode(self, x):
|
| 310 |
# self.model.vae.disable_tiling()
|
| 311 |
if x.shape[2] % 4:
|
| 312 |
x = torch.nn.functional.pad(x, (0, 0, 4 - (x.shape[2] % 4), 0))
|
| 313 |
return (self.model.vae.encode(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
|
| 314 |
# return (self.encode_no_tiling(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
|
| 315 |
|
| 316 |
-
def vae_decode(self, x):
|
| 317 |
return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample
|
| 318 |
|
| 319 |
-
def decode_to_mel(self, x):
|
| 320 |
if self.double_precision:
|
| 321 |
tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().double()).detach()
|
| 322 |
tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().float()).detach()
|
|
@@ -324,7 +221,9 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
| 324 |
tmp = tmp.unsqueeze(0)
|
| 325 |
return tmp
|
| 326 |
|
| 327 |
-
def encode_text(self, prompts: List[str]
|
tokenizers = [self.model.tokenizer, self.model.tokenizer_2]
|
| 329 |
text_encoders = [self.model.text_encoder, self.model.text_encoder_2]
|
| 330 |
prompt_embeds_list = []
|
|
@@ -333,8 +232,11 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
| 333 |
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
| 334 |
text_inputs = tokenizer(
|
| 335 |
prompts,
|
| 336 |
-
padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
|
| 337 |
-
|
|
| 338 |
truncation=True,
|
| 339 |
return_tensors="pt",
|
| 340 |
)
|
|
@@ -404,7 +306,7 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
| 404 |
|
| 405 |
return generated_prompt_embeds, prompt_embeds, attention_mask
|
| 406 |
|
| 407 |
-
def get_variance(self, timestep, prev_timestep):
|
| 408 |
alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
|
| 409 |
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
|
| 410 |
beta_prod_t = 1 - alpha_prod_t
|
|
@@ -412,7 +314,7 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
| 412 |
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
| 413 |
return variance
|
| 414 |
|
| 415 |
-
def get_alpha_prod_t_prev(self, prev_timestep):
|
| 416 |
return self.model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
|
| 417 |
else self.model.scheduler.final_alpha_cumprod
|
| 418 |
|
|
@@ -485,8 +387,6 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
| 485 |
# 1. time
|
| 486 |
timesteps = timestep
|
| 487 |
if not torch.is_tensor(timesteps):
|
| 488 |
-
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
| 489 |
-
# This would be a good case for the `match` statement (Python 3.10+)
|
| 490 |
is_mps = sample.device.type == "mps"
|
| 491 |
if isinstance(timestep, float):
|
| 492 |
dtype = torch.float32 if is_mps else torch.float64
|
|
@@ -628,12 +528,328 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
| 628 |
|
| 629 |
return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
|
| 630 |
|
| 631 |
-
|
| 632 |
-
|
|
|
|
|
|
|
|
|
|
|
| 633 |
|
| 634 |
|
| 635 |
-
def load_model(model_id, device,
|
| 636 |
-
|
|
|
|
| 637 |
ldm_stable.load_scheduler()
|
| 638 |
torch.cuda.empty_cache()
|
| 639 |
return ldm_stable
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from diffusers import DDIMScheduler, CosineDPMSolverMultistepScheduler
|
| 3 |
+
from diffusers.schedulers.scheduling_dpmsolver_sde import BrownianTreeNoiseSampler
|
| 4 |
+
from diffusers import AudioLDM2Pipeline, StableAudioPipeline
|
| 5 |
+
from transformers import RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer
|
| 6 |
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
|
| 7 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 8 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 9 |
+
import gradio as gr
|
| 10 |
|
| 11 |
|
| 12 |
class PipelineWrapper(torch.nn.Module):
|
| 13 |
+
def __init__(self, model_id: str,
|
| 14 |
+
device: torch.device,
|
| 15 |
+
double_precision: bool = False,
|
| 16 |
+
token: Optional[str] = None, *args, **kwargs) -> None:
|
| 17 |
super().__init__(*args, **kwargs)
|
| 18 |
self.model_id = model_id
|
| 19 |
self.device = device
|
| 20 |
self.double_precision = double_precision
|
| 21 |
+
self.token = token
|
| 22 |
|
| 23 |
+
def get_sigma(self, timestep: int) -> float:
|
| 24 |
sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.model.scheduler.alphas_cumprod - 1)
|
| 25 |
return sqrt_recipm1_alphas_cumprod[timestep]
|
| 26 |
|
| 27 |
+
def load_scheduler(self) -> None:
|
| 28 |
pass
|
| 29 |
|
| 30 |
+
def get_fn_STFT(self) -> torch.nn.Module:
|
| 31 |
pass
|
| 32 |
|
| 33 |
+
def get_sr(self) -> int:
|
| 34 |
+
return 16000
|
| 35 |
+
|
| 36 |
+
def vae_encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 37 |
+
pass
|
| 38 |
+
|
| 39 |
+
def vae_decode(self, x: torch.Tensor) -> torch.Tensor:
|
| 40 |
pass
|
| 41 |
|
| 42 |
+
def decode_to_mel(self, x: torch.Tensor) -> torch.Tensor:
|
| 43 |
pass
|
| 44 |
|
| 45 |
+
def setup_extra_inputs(self, *args, **kwargs) -> None:
|
| 46 |
pass
|
| 47 |
|
| 48 |
+
def encode_text(self, prompts: List[str], **kwargs
|
| 49 |
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 50 |
pass
|
| 51 |
|
| 52 |
+
def get_variance(self, timestep: torch.Tensor, prev_timestep: torch.Tensor) -> torch.Tensor:
|
| 53 |
pass
|
| 54 |
|
| 55 |
+
def get_alpha_prod_t_prev(self, prev_timestep: torch.Tensor) -> torch.Tensor:
|
| 56 |
pass
|
| 57 |
|
| 58 |
+
def get_noise_shape(self, x0: torch.Tensor, num_steps: int) -> Tuple[int, ...]:
|
| 59 |
+
variance_noise_shape = (num_steps,
|
| 60 |
+
self.model.unet.config.in_channels,
|
| 61 |
+
x0.shape[-2],
|
| 62 |
+
x0.shape[-1])
|
| 63 |
+
return variance_noise_shape
|
| 64 |
+
|
| 65 |
+
def sample_xts_from_x0(self, x0: torch.Tensor, num_inference_steps: int = 50) -> torch.Tensor:
|
| 66 |
+
"""
|
| 67 |
+
Samples from P(x_1:T|x_0)
|
| 68 |
+
"""
|
| 69 |
+
alpha_bar = self.model.scheduler.alphas_cumprod
|
| 70 |
+
sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5
|
| 71 |
+
|
| 72 |
+
variance_noise_shape = self.get_noise_shape(x0, num_inference_steps + 1)
|
| 73 |
+
timesteps = self.model.scheduler.timesteps.to(self.device)
|
| 74 |
+
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
| 75 |
+
xts = torch.zeros(variance_noise_shape).to(x0.device)
|
| 76 |
+
xts[0] = x0
|
| 77 |
+
for t in reversed(timesteps):
|
| 78 |
+
idx = num_inference_steps - t_to_idx[int(t)]
|
| 79 |
+
xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
|
| 80 |
+
|
| 81 |
+
return xts
|
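sample_xts_from_x0 above draws each noisy latent independently from the diffusion marginal; in the scheduler's notation this is (a paraphrase, not text from the commit)

x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0, I),

which is exactly the alpha_bar / sqrt_one_minus_alpha_bar combination used inside the loop.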
| 82 |
+
|
| 83 |
+
def get_zs_from_xts(self, xt: torch.Tensor, xtm1: torch.Tensor, noise_pred: torch.Tensor,
|
| 84 |
+
t: torch.Tensor, eta: float = 0, numerical_fix: bool = True, **kwargs
|
| 85 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
| 86 |
+
# pred of x0
|
| 87 |
+
alpha_bar = self.model.scheduler.alphas_cumprod
|
| 88 |
+
if self.model.scheduler.config.prediction_type == 'epsilon':
|
| 89 |
+
pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
|
| 90 |
+
elif self.model.scheduler.config.prediction_type == 'v_prediction':
|
| 91 |
+
pred_original_sample = (alpha_bar[t] ** 0.5) * xt - ((1 - alpha_bar[t]) ** 0.5) * noise_pred
|
| 92 |
+
|
| 93 |
+
# direction to xt
|
| 94 |
+
prev_timestep = t - self.model.scheduler.config.num_train_timesteps // \
|
| 95 |
+
self.model.scheduler.num_inference_steps
|
| 96 |
+
|
| 97 |
+
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
|
| 98 |
+
variance = self.get_variance(t, prev_timestep)
|
| 99 |
+
|
| 100 |
+
if self.model.scheduler.config.prediction_type == 'epsilon':
|
| 101 |
+
radom_noise_pred = noise_pred
|
| 102 |
+
elif self.model.scheduler.config.prediction_type == 'v_prediction':
|
| 103 |
+
radom_noise_pred = (alpha_bar[t] ** 0.5) * noise_pred + ((1 - alpha_bar[t]) ** 0.5) * xt
|
| 104 |
+
|
| 105 |
+
pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * radom_noise_pred
|
| 106 |
+
|
| 107 |
+
mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
| 108 |
+
|
| 109 |
+
z = (xtm1 - mu_xt) / (eta * variance ** 0.5)
|
| 110 |
+
|
| 111 |
+
# correction to avoid error accumulation
|
| 112 |
+
if numerical_fix:
|
| 113 |
+
xtm1 = mu_xt + (eta * variance ** 0.5)*z
|
| 114 |
+
|
| 115 |
+
return z, xtm1, None
|
| 116 |
+
|
| 117 |
+
def reverse_step_with_custom_noise(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor,
|
| 118 |
+
variance_noise: Optional[torch.Tensor] = None, eta: float = 0, **kwargs
|
| 119 |
+
) -> torch.Tensor:
|
| 120 |
+
# 1. get previous step value (=t-1)
|
| 121 |
+
prev_timestep = timestep - self.model.scheduler.config.num_train_timesteps // \
|
| 122 |
+
self.model.scheduler.num_inference_steps
|
| 123 |
+
# 2. compute alphas, betas
|
| 124 |
+
alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
|
| 125 |
+
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
|
| 126 |
+
beta_prod_t = 1 - alpha_prod_t
|
| 127 |
+
# 3. compute predicted original sample from predicted noise also called
|
| 128 |
+
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
| 129 |
+
if self.model.scheduler.config.prediction_type == 'epsilon':
|
| 130 |
+
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
| 131 |
+
elif self.model.scheduler.config.prediction_type == 'v_prediction':
|
| 132 |
+
pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
|
| 133 |
+
|
| 134 |
+
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
| 135 |
+
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
| 136 |
+
# variance = self.scheduler._get_variance(timestep, prev_timestep)
|
| 137 |
+
variance = self.get_variance(timestep, prev_timestep)
|
| 138 |
+
# std_dev_t = eta * variance ** (0.5)
|
| 139 |
+
# Take care of asymmetric reverse process (asyrp)
|
| 140 |
+
if self.model.scheduler.config.prediction_type == 'epsilon':
|
| 141 |
+
model_output_direction = model_output
|
| 142 |
+
elif self.model.scheduler.config.prediction_type == 'v_prediction':
|
| 143 |
+
model_output_direction = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
| 144 |
+
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
| 145 |
+
# pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
|
| 146 |
+
pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
|
| 147 |
+
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
| 148 |
+
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
| 149 |
+
# 8. Add noise if eta > 0
|
| 150 |
+
if eta > 0:
|
| 151 |
+
if variance_noise is None:
|
| 152 |
+
variance_noise = torch.randn(model_output.shape, device=self.device)
|
| 153 |
+
sigma_z = eta * variance ** (0.5) * variance_noise
|
| 154 |
+
prev_sample = prev_sample + sigma_z
|
| 155 |
+
|
| 156 |
+
return prev_sample
|
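For context, reverse_step_with_custom_noise follows Eq. (12) of DDIM (arXiv:2010.02502) with the variance scaled by eta as implemented above; written out (a paraphrase of the code, not text from the commit),

x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0 \;+\; \sqrt{1 - \bar{\alpha}_{t-1} - \eta\,\sigma_t^2}\;\hat{\epsilon}_\theta \;+\; \eta\,\sigma_t\, z_t,

where \sigma_t^2 is get_variance(t, t_prev), \hat{\epsilon}_\theta is the (possibly v-prediction-converted) model output direction, and z_t is the supplied variance_noise.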
| 157 |
+
|
| 158 |
def unet_forward(self,
|
| 159 |
sample: torch.FloatTensor,
|
| 160 |
timestep: Union[torch.Tensor, float, int],
|
|
|
|
| 171 |
replace_skip_conns: Optional[Dict[int, torch.Tensor]] = None,
|
| 172 |
return_dict: bool = True,
|
| 173 |
zero_out_resconns: Optional[Union[int, List]] = None) -> Tuple:
|
| 174 |
+
pass
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
|
| 177 |
class AudioLDM2Wrapper(PipelineWrapper):
|
| 178 |
def __init__(self, *args, **kwargs) -> None:
|
| 179 |
super().__init__(*args, **kwargs)
|
| 180 |
if self.double_precision:
|
| 181 |
+
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, torch_dtype=torch.float64, token=self.token
|
| 182 |
+
).to(self.device)
|
| 183 |
else:
|
| 184 |
try:
|
| 185 |
+
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=True, token=self.token
|
| 186 |
+
).to(self.device)
|
| 187 |
except FileNotFoundError:
|
| 188 |
+
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=False, token=self.token
|
| 189 |
+
).to(self.device)
|
| 190 |
|
| 191 |
+
def load_scheduler(self) -> None:
|
|
|
|
| 192 |
self.model.scheduler = DDIMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
|
| 193 |
|
| 194 |
+
def get_fn_STFT(self) -> torch.nn.Module:
|
| 195 |
from audioldm.audio import TacotronSTFT
|
| 196 |
return TacotronSTFT(
|
| 197 |
filter_length=1024,
|
|
|
|
| 203 |
mel_fmax=8000,
|
| 204 |
)
|
| 205 |
|
| 206 |
+
def vae_encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 207 |
# self.model.vae.disable_tiling()
|
| 208 |
if x.shape[2] % 4:
|
| 209 |
x = torch.nn.functional.pad(x, (0, 0, 4 - (x.shape[2] % 4), 0))
|
| 210 |
return (self.model.vae.encode(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
|
| 211 |
# return (self.encode_no_tiling(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
|
| 212 |
|
| 213 |
+
def vae_decode(self, x: torch.Tensor) -> torch.Tensor:
|
| 214 |
return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample
|
| 215 |
|
| 216 |
+
def decode_to_mel(self, x: torch.Tensor) -> torch.Tensor:
|
| 217 |
if self.double_precision:
|
| 218 |
tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().double()).detach()
|
| 219 |
tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().float()).detach()
|
|
|
|
| 221 |
tmp = tmp.unsqueeze(0)
|
| 222 |
return tmp
|
| 223 |
|
| 224 |
+
def encode_text(self, prompts: List[str], negative: bool = False,
|
| 225 |
+
save_compute: bool = False, cond_length: int = 0, **kwargs
|
| 226 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 227 |
tokenizers = [self.model.tokenizer, self.model.tokenizer_2]
|
| 228 |
text_encoders = [self.model.text_encoder, self.model.text_encoder_2]
|
| 229 |
prompt_embeds_list = []
|
|
|
|
| 232 |
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
| 233 |
text_inputs = tokenizer(
|
| 234 |
prompts,
|
| 235 |
+
padding="max_length" if (save_compute and negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
|
| 236 |
+
else True,
|
| 237 |
+
max_length=tokenizer.model_max_length
|
| 238 |
+
if (not save_compute) or ((not negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer)))
|
| 239 |
+
else cond_length,
|
| 240 |
truncation=True,
|
| 241 |
return_tensors="pt",
|
| 242 |
)
|
|
|
|
| 306 |
|
| 307 |
return generated_prompt_embeds, prompt_embeds, attention_mask
|
| 308 |
|
| 309 |
+
def get_variance(self, timestep: torch.Tensor, prev_timestep: torch.Tensor) -> torch.Tensor:
|
| 310 |
alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
|
| 311 |
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
|
| 312 |
beta_prod_t = 1 - alpha_prod_t
|
|
|
|
| 314 |
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
| 315 |
return variance
|
| 316 |
|
| 317 |
+
def get_alpha_prod_t_prev(self, prev_timestep: torch.Tensor) -> torch.Tensor:
|
| 318 |
return self.model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
|
| 319 |
else self.model.scheduler.final_alpha_cumprod
|
| 320 |
|
|
|
|
| 387 |
# 1. time
|
| 388 |
timesteps = timestep
|
| 389 |
if not torch.is_tensor(timesteps):
|
| 390 |
is_mps = sample.device.type == "mps"
|
| 391 |
if isinstance(timestep, float):
|
| 392 |
dtype = torch.float32 if is_mps else torch.float64
|
|
|
|
| 528 |
|
| 529 |
return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
|
| 530 |
|
| 531 |
+
|
| 532 |
+
class StableAudWrapper(PipelineWrapper):
|
| 533 |
+
def __init__(self, *args, **kwargs) -> None:
|
| 534 |
+
super().__init__(*args, **kwargs)
|
| 535 |
+
try:
|
| 536 |
+
self.model = StableAudioPipeline.from_pretrained(self.model_id, token=self.token, local_files_only=True
|
| 537 |
+
).to(self.device)
|
| 538 |
+
except FileNotFoundError:
|
| 539 |
+
self.model = StableAudioPipeline.from_pretrained(self.model_id, token=self.token, local_files_only=False
|
| 540 |
+
).to(self.device)
|
| 541 |
+
self.model.transformer.eval()
|
| 542 |
+
self.model.vae.eval()
|
| 543 |
+
|
| 544 |
+
if self.double_precision:
|
| 545 |
+
self.model = self.model.to(torch.float64)
|
| 546 |
+
|
| 547 |
+
def load_scheduler(self) -> None:
|
| 548 |
+
self.model.scheduler = CosineDPMSolverMultistepScheduler.from_pretrained(
|
| 549 |
+
self.model_id, subfolder="scheduler", token=self.token)
|
| 550 |
+
|
| 551 |
+
def encode_text(self, prompts: List[str], negative: bool = False, **kwargs) -> Tuple[torch.Tensor, None, torch.Tensor]:
|
| 552 |
+
text_inputs = self.model.tokenizer(
|
| 553 |
+
prompts,
|
| 554 |
+
padding="max_length",
|
| 555 |
+
max_length=self.model.tokenizer.model_max_length,
|
| 556 |
+
truncation=True,
|
| 557 |
+
return_tensors="pt",
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
text_input_ids = text_inputs.input_ids.to(self.device)
|
| 561 |
+
attention_mask = text_inputs.attention_mask.to(self.device)
|
| 562 |
+
|
| 563 |
+
self.model.text_encoder.eval()
|
| 564 |
+
with torch.no_grad():
|
| 565 |
+
prompt_embeds = self.model.text_encoder(text_input_ids, attention_mask=attention_mask)[0]
|
| 566 |
+
|
| 567 |
+
if negative and attention_mask is not None: # set the masked tokens to the null embed
|
| 568 |
+
prompt_embeds = torch.where(attention_mask.to(torch.bool).unsqueeze(2), prompt_embeds, 0.0)
|
| 569 |
+
|
| 570 |
+
prompt_embeds = self.model.projection_model(text_hidden_states=prompt_embeds).text_hidden_states
|
| 571 |
+
|
| 572 |
+
if attention_mask is None:
|
| 573 |
+
raise gr.Error("Shouldn't reach here. Please raise an issue if you do.")
|
| 574 |
+
"""prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
| 575 |
+
if attention_mask is not None and negative_attention_mask is None:
|
| 576 |
+
negative_attention_mask = torch.ones_like(attention_mask)
|
| 577 |
+
elif attention_mask is None and negative_attention_mask is not None:
|
| 578 |
+
attention_mask = torch.ones_like(negative_attention_mask)"""
|
| 579 |
+
|
| 580 |
+
if prompts == [""]: # empty
|
| 581 |
+
return torch.zeros_like(prompt_embeds, device=prompt_embeds.device), None, None
|
| 582 |
+
|
| 583 |
+
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
|
| 584 |
+
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
|
| 585 |
+
return prompt_embeds, None, attention_mask
|
| 586 |
+
|
| 587 |
+
def get_fn_STFT(self) -> torch.nn.Module:
|
| 588 |
+
from audioldm.audio import TacotronSTFT
|
| 589 |
+
return TacotronSTFT(
|
| 590 |
+
filter_length=1024,
|
| 591 |
+
hop_length=160,
|
| 592 |
+
win_length=1024,
|
| 593 |
+
n_mel_channels=64,
|
| 594 |
+
sampling_rate=44100,
|
| 595 |
+
mel_fmin=0,
|
| 596 |
+
mel_fmax=22050,
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
def vae_encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 600 |
+
x = x.unsqueeze(0)
|
| 601 |
+
|
| 602 |
+
audio_vae_length = int(self.model.transformer.config.sample_size * self.model.vae.hop_length)
|
| 603 |
+
audio_shape = (1, self.model.vae.config.audio_channels, audio_vae_length)
|
| 604 |
+
|
| 605 |
+
# check num_channels
|
| 606 |
+
if x.shape[1] == 1 and self.model.vae.config.audio_channels == 2:
|
| 607 |
+
x = x.repeat(1, 2, 1)
|
| 608 |
+
|
| 609 |
+
audio_length = x.shape[-1]
|
| 610 |
+
audio = x.new_zeros(audio_shape)
|
| 611 |
+
audio[:, :, : min(audio_length, audio_vae_length)] = x[:, :, :audio_vae_length]
|
| 612 |
+
|
| 613 |
+
encoded_audio = self.model.vae.encode(audio.to(self.device)).latent_dist
|
| 614 |
+
encoded_audio = encoded_audio.sample()
|
| 615 |
+
return encoded_audio
|
| 616 |
+
|
| 617 |
+
def vae_decode(self, x: torch.Tensor) -> torch.Tensor:
|
| 618 |
+
torch.cuda.empty_cache()
|
| 619 |
+
# return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample
|
| 620 |
+
aud = self.model.vae.decode(x).sample
|
| 621 |
+
return aud[:, :, self.waveform_start:self.waveform_end]
|
| 622 |
+
|
| 623 |
+
def setup_extra_inputs(self, x: torch.Tensor, init_timestep: torch.Tensor,
|
| 624 |
+
extra_info: Optional[Any] = None,
|
| 625 |
+
audio_start_in_s: float = 0, audio_end_in_s: Optional[float] = None,
|
| 626 |
+
save_compute: bool = False) -> None:
|
| 627 |
+
max_audio_length_in_s = self.model.transformer.config.sample_size * self.model.vae.hop_length / \
|
| 628 |
+
self.model.vae.config.sampling_rate
|
| 629 |
+
if audio_end_in_s is None:
|
| 630 |
+
audio_end_in_s = max_audio_length_in_s
|
| 631 |
+
|
| 632 |
+
if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
|
| 633 |
+
raise ValueError(
|
| 634 |
+
f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer "
|
| 635 |
+
f"than the model maximum possible length ({max_audio_length_in_s}). "
|
| 636 |
+
f"Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
|
| 637 |
+
)
|
| 638 |
+
|
| 639 |
+
self.waveform_start = int(audio_start_in_s * self.model.vae.config.sampling_rate)
|
| 640 |
+
self.waveform_end = int(audio_end_in_s * self.model.vae.config.sampling_rate)
|
| 641 |
+
|
| 642 |
+
self.seconds_start_hidden_states, self.seconds_end_hidden_states = self.model.encode_duration(
|
| 643 |
+
audio_start_in_s, audio_end_in_s, self.device, False, 1)
|
| 644 |
+
|
| 645 |
+
if save_compute:
|
| 646 |
+
self.seconds_start_hidden_states = torch.cat([self.seconds_start_hidden_states, self.seconds_start_hidden_states], dim=0)
|
| 647 |
+
self.seconds_end_hidden_states = torch.cat([self.seconds_end_hidden_states, self.seconds_end_hidden_states], dim=0)
|
| 648 |
+
|
| 649 |
+
self.audio_duration_embeds = torch.cat([self.seconds_start_hidden_states,
|
| 650 |
+
self.seconds_end_hidden_states], dim=2)
|
| 651 |
+
|
| 652 |
+
# 7. Prepare rotary positional embedding
|
| 653 |
+
self.rotary_embedding = get_1d_rotary_pos_embed(
|
| 654 |
+
self.model.rotary_embed_dim,
|
| 655 |
+
x.shape[2] + self.audio_duration_embeds.shape[1],
|
| 656 |
+
use_real=True,
|
| 657 |
+
repeat_interleave_real=False,
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
self.model.scheduler._init_step_index(init_timestep)
|
| 661 |
+
|
| 662 |
+
# fix lower_order_nums for the reverse step - Option 1: only start from first order
|
| 663 |
+
# self.model.scheduler.lower_order_nums = 0
|
| 664 |
+
# self.model.scheduler.model_outputs = [None] * self.model.scheduler.config.solver_order
|
| 665 |
+
# fix lower_order_nums for the reverse step - Option 2: start from the correct order with history
|
| 666 |
+
t_to_idx = {float(v): k for k, v in enumerate(self.model.scheduler.timesteps)}
|
| 667 |
+
idx = len(self.model.scheduler.timesteps) - t_to_idx[float(init_timestep)] - 1
|
| 668 |
+
self.model.scheduler.model_outputs = [None, extra_info[idx] if extra_info is not None else None]
|
| 669 |
+
self.model.scheduler.lower_order_nums = min(self.model.scheduler.step_index,
|
| 670 |
+
self.model.scheduler.config.solver_order)
|
| 671 |
+
|
| 672 |
+
# if rand check:
|
| 673 |
+
# x *= self.model.scheduler.init_noise_sigma
|
| 674 |
+
# return x
|
| 675 |
+
|
| 676 |
+
def sample_xts_from_x0(self, x0: torch.Tensor, num_inference_steps: int = 50) -> torch.Tensor:
|
| 677 |
+
"""
|
| 678 |
+
Samples from P(x_1:T|x_0)
|
| 679 |
+
"""
|
| 680 |
+
|
| 681 |
+
sigmas = self.model.scheduler.sigmas
|
| 682 |
+
shapes = self.get_noise_shape(x0, num_inference_steps + 1)
|
| 683 |
+
xts = torch.zeros(shapes).to(x0.device)
|
| 684 |
+
xts[0] = x0
|
| 685 |
+
|
| 686 |
+
timesteps = self.model.scheduler.timesteps.to(self.device)
|
| 687 |
+
t_to_idx = {float(v): k for k, v in enumerate(timesteps)}
|
| 688 |
+
for t in reversed(timesteps):
|
| 689 |
+
# idx = t_to_idx[int(t)]
|
| 690 |
+
idx = num_inference_steps - t_to_idx[float(t)]
|
| 691 |
+
n = torch.randn_like(x0)
|
| 692 |
+
xts[idx] = x0 + n * sigmas[t_to_idx[float(t)]]
|
| 693 |
+
return xts
|
| 694 |
+
|
| 695 |
+
def get_zs_from_xts(self, xt: torch.Tensor, xtm1: torch.Tensor, data_pred: torch.Tensor,
|
| 696 |
+
t: torch.Tensor, numerical_fix: bool = True, first_order: bool = False, **kwargs
|
| 697 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
| 698 |
+
# pred of x0
|
| 699 |
+
sigmas = self.model.scheduler.sigmas
|
| 700 |
+
timesteps = self.model.scheduler.timesteps
|
| 701 |
+
solver_order = self.model.scheduler.config.solver_order
|
| 702 |
+
|
| 703 |
+
if self.model.scheduler.step_index is None:
|
| 704 |
+
self.model.scheduler._init_step_index(t)
|
| 705 |
+
curr_step_index = self.model.scheduler.step_index
|
| 706 |
+
|
| 707 |
+
# Improve numerical stability for small number of steps
|
| 708 |
+
lower_order_final = (curr_step_index == len(timesteps) - 1) and (
|
| 709 |
+
self.model.scheduler.config.euler_at_final
|
| 710 |
+
or (self.model.scheduler.config.lower_order_final and len(timesteps) < 15)
|
| 711 |
+
or self.model.scheduler.config.final_sigmas_type == "zero")
|
| 712 |
+
lower_order_second = ((curr_step_index == len(timesteps) - 2) and
|
| 713 |
+
self.model.scheduler.config.lower_order_final and len(timesteps) < 15)
|
| 714 |
+
|
| 715 |
+
data_pred = self.model.scheduler.convert_model_output(data_pred, sample=xt)
|
| 716 |
+
for i in range(solver_order - 1):
|
| 717 |
+
self.model.scheduler.model_outputs[i] = self.model.scheduler.model_outputs[i + 1]
|
| 718 |
+
self.model.scheduler.model_outputs[-1] = data_pred
|
| 719 |
+
|
| 720 |
+
# instead of brownian noise, here we calculate the noise ourselves
|
| 721 |
+
if (curr_step_index == len(timesteps) - 1) and self.model.scheduler.config.final_sigmas_type == "zero":
|
| 722 |
+
z = torch.zeros_like(xt)
|
| 723 |
+
elif first_order or solver_order == 1 or self.model.scheduler.lower_order_nums < 1 or lower_order_final:
|
| 724 |
+
sigma_t, sigma_s = sigmas[curr_step_index + 1], sigmas[curr_step_index]
|
| 725 |
+
h = torch.log(sigma_s) - torch.log(sigma_t)
|
| 726 |
+
z = (xtm1 - (sigma_t / sigma_s * torch.exp(-h)) * xt - (1 - torch.exp(-2.0 * h)) * data_pred) \
|
| 727 |
+
/ (sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)))
|
| 728 |
+
elif solver_order == 2 or self.model.scheduler.lower_order_nums < 2 or lower_order_second:
|
| 729 |
+
sigma_t = sigmas[curr_step_index + 1]
|
| 730 |
+
sigma_s0 = sigmas[curr_step_index]
|
| 731 |
+
sigma_s1 = sigmas[curr_step_index - 1]
|
| 732 |
+
m0, m1 = self.model.scheduler.model_outputs[-1], self.model.scheduler.model_outputs[-2]
|
| 733 |
+
h, h_0 = torch.log(sigma_s0) - torch.log(sigma_t), torch.log(sigma_s1) - torch.log(sigma_s0)
|
| 734 |
+
r0 = h_0 / h
|
| 735 |
+
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
| 736 |
+
|
| 737 |
+
# sde-dpmsolver++
|
| 738 |
+
z = (xtm1 - (sigma_t / sigma_s0 * torch.exp(-h)) * xt
|
| 739 |
+
- (1 - torch.exp(-2.0 * h)) * D0
|
| 740 |
+
- 0.5 * (1 - torch.exp(-2.0 * h)) * D1) \
|
| 741 |
+
/ (sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)))
|
| 742 |
+
|
| 743 |
+
# correction to avoid error accumulation
|
| 744 |
+
if numerical_fix:
|
| 745 |
+
if first_order or solver_order == 1 or self.model.scheduler.lower_order_nums < 1 or lower_order_final:
|
| 746 |
+
xtm1 = self.model.scheduler.dpm_solver_first_order_update(data_pred, sample=xt, noise=z)
|
| 747 |
+
elif solver_order == 2 or self.model.scheduler.lower_order_nums < 2 or lower_order_second:
|
| 748 |
+
xtm1 = self.model.scheduler.multistep_dpm_solver_second_order_update(
|
| 749 |
+
self.model.scheduler.model_outputs, sample=xt, noise=z)
|
| 750 |
+
# If not perfect recon - maybe TODO fix self.model.scheduler.model_outputs as well?
|
| 751 |
+
|
| 752 |
+
if self.model.scheduler.lower_order_nums < solver_order:
|
| 753 |
+
self.model.scheduler.lower_order_nums += 1
|
| 754 |
+
# upon completion increase step index by one
|
| 755 |
+
self.model.scheduler._step_index += 1
|
| 756 |
+
|
| 757 |
+
return z, xtm1, self.model.scheduler.model_outputs[-2]
|
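In the first-order branch of get_zs_from_xts above, z is recovered by inverting the SDE-DPMSolver++ update; with h = \log\sigma_s - \log\sigma_t the step reads (a paraphrase of the code, not text from the commit)

x_{t-1} = \frac{\sigma_t}{\sigma_s}\, e^{-h}\, x_t \;+\; \bigl(1 - e^{-2h}\bigr)\,\hat{x}_0 \;+\; \sigma_t\sqrt{1 - e^{-2h}}\; z,

so the code simply solves this relation for z, and the second-order branch does the same with the D0/D1 history terms in place of \hat{x}_0.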
| 758 |
+
|
| 759 |
+
+    def get_sr(self) -> int:
+        return self.model.vae.config.sampling_rate
+
+    def get_noise_shape(self, x0: torch.Tensor, num_steps: int) -> Tuple[int, int, int]:
+        variance_noise_shape = (num_steps,
+                                self.model.transformer.config.in_channels,
+                                int(self.model.transformer.config.sample_size))
+        return variance_noise_shape
+
+    def reverse_step_with_custom_noise(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor,
+                                       variance_noise: Optional[torch.Tensor] = None,
+                                       first_order: bool = False, **kwargs
+                                       ) -> torch.Tensor:
+        if self.model.scheduler.step_index is None:
+            self.model.scheduler._init_step_index(timestep)
+
+        # Improve numerical stability for small number of steps
+        lower_order_final = (self.model.scheduler.step_index == len(self.model.scheduler.timesteps) - 1) and (
+            self.model.scheduler.config.euler_at_final
+            or (self.model.scheduler.config.lower_order_final and len(self.model.scheduler.timesteps) < 15)
+            or self.model.scheduler.config.final_sigmas_type == "zero"
+        )
+        lower_order_second = (
+            (self.model.scheduler.step_index == len(self.model.scheduler.timesteps) - 2) and
+            self.model.scheduler.config.lower_order_final and len(self.model.scheduler.timesteps) < 15
+        )
+
+        model_output = self.model.scheduler.convert_model_output(model_output, sample=sample)
+        for i in range(self.model.scheduler.config.solver_order - 1):
+            self.model.scheduler.model_outputs[i] = self.model.scheduler.model_outputs[i + 1]
+        self.model.scheduler.model_outputs[-1] = model_output
+
+        if variance_noise is None:
+            if self.model.scheduler.noise_sampler is None:
+                self.model.scheduler.noise_sampler = BrownianTreeNoiseSampler(
+                    model_output, sigma_min=self.model.scheduler.config.sigma_min,
+                    sigma_max=self.model.scheduler.config.sigma_max, seed=None)
+            variance_noise = self.model.scheduler.noise_sampler(
+                self.model.scheduler.sigmas[self.model.scheduler.step_index],
+                self.model.scheduler.sigmas[self.model.scheduler.step_index + 1]).to(model_output.device)
+
+        if first_order or self.model.scheduler.config.solver_order == 1 or \
+                self.model.scheduler.lower_order_nums < 1 or lower_order_final:
+            prev_sample = self.model.scheduler.dpm_solver_first_order_update(
+                model_output, sample=sample, noise=variance_noise)
+        elif self.model.scheduler.config.solver_order == 2 or \
+                self.model.scheduler.lower_order_nums < 2 or lower_order_second:
+            prev_sample = self.model.scheduler.multistep_dpm_solver_second_order_update(
+                self.model.scheduler.model_outputs, sample=sample, noise=variance_noise)
+
+        if self.model.scheduler.lower_order_nums < self.model.scheduler.config.solver_order:
+            self.model.scheduler.lower_order_nums += 1
+
+        # upon completion increase step index by one
+        self.model.scheduler._step_index += 1
+
+        return prev_sample
+
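A minimal sketch of how this reverse step might be driven with the noises extracted during inversion; `ldm_stable`, `xts`, `zs`, and `text_embeds` are illustrative assumptions, not this repo's actual editing loop:

# zs[i] is assumed to hold the noise recovered for the i-th reverse step.
xt = xts[-1]  # most-noised latent from the inversion
for i, t in enumerate(ldm_stable.model.scheduler.timesteps):
    out = ldm_stable.unet_forward(xt, t, encoder_hidden_states=text_embeds)[0]
    xt = ldm_stable.reverse_step_with_custom_noise(out.sample, t, xt, variance_noise=zs[i])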
+    def unet_forward(self,
+                     sample: torch.FloatTensor,
+                     timestep: Union[torch.Tensor, float, int],
+                     encoder_hidden_states: torch.Tensor,
+                     encoder_attention_mask: Optional[torch.Tensor] = None,
+                     return_dict: bool = True,
+                     **kwargs) -> Tuple:
+
+        # Create text_audio_duration_embeds and audio_duration_embeds
+        embeds = torch.cat([encoder_hidden_states, self.seconds_start_hidden_states, self.seconds_end_hidden_states],
+                           dim=1)
+        if encoder_attention_mask is None:
+            # handle the batched case
+            if embeds.shape[0] > 1:
+                embeds[0] = torch.zeros_like(embeds[0], device=embeds.device)
+            else:
+                embeds = torch.zeros_like(embeds, device=embeds.device)
+
+        noise_pred = self.model.transformer(sample,
+                                            timestep.unsqueeze(0),
+                                            encoder_hidden_states=embeds,
+                                            global_hidden_states=self.audio_duration_embeds,
+                                            rotary_embedding=self.rotary_embedding)
+
+        if not return_dict:
+            return (noise_pred.sample,)
+
+        return noise_pred, None, None


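When no attention mask is passed, the branch above zeroes the text embeddings of the first batch element, i.e. it treats index 0 as the null prompt of a classifier-free-guidance pair. A hedged illustration of the combine this enables (`guidance_scale` and the chunk order are assumptions, not taken from this repo):

# Assumes the batch was built as [unconditional, conditional] before calling unet_forward.
noise_uncond, noise_cond = noise_pred.sample.chunk(2, dim=0)
guided = noise_uncond + guidance_scale * (noise_cond - noise_uncond)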
+def load_model(model_id: str, device: torch.device,
+               double_precision: bool = False, token: Optional[str] = None) -> PipelineWrapper:
+    if 'audioldm2' in model_id:
+        ldm_stable = AudioLDM2Wrapper(model_id=model_id, device=device, double_precision=double_precision, token=token)
+    elif 'stable-audio' in model_id:
+        ldm_stable = StableAudWrapper(model_id=model_id, device=device, double_precision=double_precision, token=token)
    ldm_stable.load_scheduler()
    torch.cuda.empty_cache()
    return ldm_stable
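A minimal usage sketch for this factory; the checkpoint id is an assumption (any id containing 'stable-audio' selects StableAudWrapper), and note that an id matching neither branch would leave `ldm_stable` unbound:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ldm_stable = load_model('stabilityai/stable-audio-open-1.0', device=device)  # illustrative model id
sr = ldm_stable.get_sr()  # sampling rate reported by the wrapper's VAE config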
requirements.txt
CHANGED
@@ -1,8 +1,9 @@
-torch
-numpy<2
+torch>2.2.0
+numpy<2.0.0
 torchaudio
 diffusers
 accelerate
+torchsde
 transformers
 tqdm
 soundfile
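The new torchsde entry is presumably what backs the `BrownianTreeNoiseSampler` used in `reverse_step_with_custom_noise` (an assumption based on the diffusers SDE schedulers); a quick sanity check:

import torchsde  # noqa: F401 -- expected to be pulled in by the Brownian-tree noise sampler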
utils.py
CHANGED
@@ -2,8 +2,11 @@ import numpy as np
 import torch
 from typing import Optional, List, Tuple, NamedTuple, Union
 from models import PipelineWrapper
+import torchaudio
 from audioldm.utils import get_duration
 
+MAX_DURATION = 30
+
 
 class PromptEmbeddings(NamedTuple):
     embedding_hidden_states: torch.Tensor
@@ -11,26 +14,57 @@ class PromptEmbeddings:
     boolean_prompt_mask: torch.Tensor
 
 
-def load_audio(audio_path: Union[str, np.array], fn_STFT, left: int = 0, right: int = 0,
-               … (the rest of the removed load_audio body is not recoverable from this view) …
+def load_audio(audio_path: Union[str, np.array], fn_STFT, left: int = 0, right: int = 0,
+               device: Optional[torch.device] = None,
+               return_wav: bool = False, stft: bool = False, model_sr: Optional[int] = None) -> torch.Tensor:
+    if stft:  # AudioLDM/tango loading to spectrogram
+        if type(audio_path) is str:
+            import audioldm
+            import audioldm.audio
 
+            duration = get_duration(audio_path)
+            if MAX_DURATION is not None:
+                duration = min(duration, MAX_DURATION)
 
+            mel, _, wav = audioldm.audio.wav_to_fbank(audio_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT)
+            mel = mel.unsqueeze(0)
+        else:
+            mel = audio_path
 
+        c, h, w = mel.shape
+        left = min(left, w - 1)
+        right = min(right, w - left - 1)
+        mel = mel[:, :, left:w - right]
+        mel = mel.unsqueeze(0).to(device)
 
+        if return_wav:
+            return mel, 16000, duration, wav
+
+        return mel, model_sr, duration
+    else:
+        waveform, sr = torchaudio.load(audio_path)
+        if sr != model_sr:
+            waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=model_sr)
+        # waveform = waveform.numpy()[0, ...]
+
+        def normalize_wav(waveform):
+            waveform = waveform - torch.mean(waveform)
+            waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
+            return waveform * 0.5
+
+        waveform = normalize_wav(waveform)
+        # waveform = waveform[None, ...]
+        # waveform = pad_wav(waveform, segment_length)
+
+        # waveform = waveform[0, ...]
+        waveform = torch.FloatTensor(waveform)
+        if MAX_DURATION is not None:
+            duration = min(waveform.shape[-1] / model_sr, MAX_DURATION)
+            waveform = waveform[:, :int(duration * model_sr)]
+
+        # cut waveform
+        duration = waveform.shape[-1] / model_sr
+        return waveform, model_sr, duration
 
 
 def get_height_of_spectrogram(length: int, ldm_stable: PipelineWrapper) -> int:
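A hedged usage sketch for the new waveform branch of `load_audio`; the example path comes from the Examples folder, and 44100 Hz is an assumption standing in for `ldm_stable.get_sr()`:

model_sr = 44100  # assumption; in practice use ldm_stable.get_sr()
waveform, sr, duration = load_audio('Examples/Beethoven.mp3', fn_STFT=None,
                                    stft=False, model_sr=model_sr)
# waveform is mean-centred, peak-normalised to ~0.5, and truncated to MAX_DURATION seconds
print(waveform.shape, sr, duration)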