Upload 116 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitattributes +47 -43
- .gitignore +28 -0
- README.md +12 -12
- app_v1v2.py +175 -0
- configs/astral_quantization/default_2048.yml +40 -0
- configs/astral_quantization/default_32.yml +40 -0
- configs/config.json +1 -0
- configs/inuse/.gitignore +0 -0
- configs/inuse/config.json +1 -0
- configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml +98 -0
- configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml +91 -0
- configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml +82 -0
- configs/v2/ar_base.yaml +0 -0
- configs/v2/dit_small.yaml +17 -0
- configs/v2/vc_wrapper.yaml +105 -0
- hf_utils.py +1 -1
- modules/__pycache__/audio.cpython-310.pyc +0 -0
- modules/__pycache__/commons.cpython-310.pyc +0 -0
- modules/__pycache__/commons.cpython-38.pyc +0 -0
- modules/__pycache__/diffusion_transformer.cpython-310.pyc +0 -0
- modules/__pycache__/flow_matching.cpython-310.pyc +0 -0
- modules/__pycache__/length_regulator.cpython-310.pyc +0 -0
- modules/__pycache__/rmvpe.cpython-310.pyc +0 -0
- modules/astral_quantization/__pycache__/bsq.cpython-310.pyc +0 -0
- modules/astral_quantization/__pycache__/convnext.cpython-310.pyc +0 -0
- modules/astral_quantization/__pycache__/default_model.cpython-310.pyc +0 -0
- modules/astral_quantization/bsq.py +569 -0
- modules/astral_quantization/convnext.py +209 -0
- modules/astral_quantization/default_model.py +73 -0
- modules/astral_quantization/transformer.py +254 -0
- modules/audio.py +82 -82
- modules/bigvgan/__pycache__/activations.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/env.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/meldataset.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/utils.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/cuda/activation1d.py +2 -2
- modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps +3 -0
- modules/bigvgan/alias_free_activation/cuda/build/.ninja_log +7 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o +3 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o +3 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp +0 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib +0 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd +3 -0
- modules/bigvgan/alias_free_activation/cuda/build/build.ninja +38 -0
- modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc +0 -0
.gitattributes
CHANGED
@@ -1,43 +1,47 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
examples/reference/dingzhen_0.wav filter=lfs diff=lfs merge=lfs -text
examples/reference/s3p2.wav filter=lfs diff=lfs merge=lfs -text
examples/source/source_s3.wav filter=lfs diff=lfs merge=lfs -text
examples/source/source_s4.wav filter=lfs diff=lfs merge=lfs -text
examples/source/Wiz[[:space:]]Khalifa,Charlie[[:space:]]Puth[[:space:]]-[[:space:]]See[[:space:]]You[[:space:]]Again[[:space:]]\[vocals\]_\[cut_28sec\].wav filter=lfs diff=lfs merge=lfs -text
examples/reference/trump_0.wav filter=lfs diff=lfs merge=lfs -text
examples/source/jay_0.wav filter=lfs diff=lfs merge=lfs -text
examples/source/TECHNOPOLIS[[:space:]]-[[:space:]]2085[[:space:]]\[vocals\]_\[cut_14sec\].wav filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o filter=lfs diff=lfs merge=lfs -text

.gitignore
ADDED
@@ -0,0 +1,28 @@
# general things to ignore
.DS_Store
build/
build_contrib/
dist/
.cache/
*.egg-info/
*.egg
*.py[cod]
__pycache__/
*.so
*~

# IDE
.vscode/
.idea/

# misc
checkpoints/
test_waves/
reconstructed/
.python-version
ruff.log
/configs/inuse/
runs/
/garbages/
/flagged/
/experimental/

README.md
CHANGED
@@ -1,13 +1,13 @@
---
title: Seed Voice Conversion
emoji: 🎤🔄
colorFrom: green
colorTo: green
sdk: gradio
-sdk_version:
-app_file:
+sdk_version: 5.23.0
+app_file: app_v1v2.py
pinned: false
license: gpl-3.0
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

app_v1v2.py
ADDED
@@ -0,0 +1,175 @@
import spaces
import gradio as gr
import torch
import yaml
import argparse
from seed_vc_wrapper import SeedVCWrapper

# Set up device and torch configurations
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True

dtype = torch.float16

def load_v2_models(args):
    from hydra.utils import instantiate
    from omegaconf import DictConfig
    cfg = DictConfig(yaml.safe_load(open("configs/v2/vc_wrapper.yaml", "r")))
    vc_wrapper = instantiate(cfg)
    vc_wrapper.load_checkpoints()
    vc_wrapper.to(device)
    vc_wrapper.eval()

    vc_wrapper.setup_ar_caches(max_batch_size=1, max_seq_len=4096, dtype=dtype, device=device)

    if args.compile:
        vc_wrapper.compile_ar()
        # vc_wrapper.compile_cfm()

    return vc_wrapper

def create_v1_interface():
    # Initialize the V1 wrapper
    vc_wrapper = SeedVCWrapper()

    # Set up Gradio interface
    description = ("Zero-shot voice conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
                   "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br> "
                   "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br> "
                   "无需训练的 zero-shot 语音/歌声转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
                   "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。")

    inputs = [
        gr.Audio(type="filepath", label="Source Audio / 源音频"),
        gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
        gr.Slider(minimum=1, maximum=200, value=10, step=1, label="Diffusion Steps / 扩散步数",
                  info="10 by default, 50~100 for best quality / 默认为 10,50~100 为最佳质量"),
        gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整",
                  info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate",
                  info="has subtle influence / 有微小影响"),
        gr.Checkbox(label="Use F0 conditioned model / 启用F0输入", value=False,
                    info="Must set to true for singing voice conversion / 歌声转换时必须勾选"),
        gr.Checkbox(label="Auto F0 adjust / 自动F0调整", value=True,
                    info="Roughly adjust F0 to match target voice. Only works when F0 conditioned model is used. / 粗略调整 F0 以匹配目标音色,仅在勾选 '启用F0输入' 时生效"),
        gr.Slider(label='Pitch shift / 音调变换', minimum=-24, maximum=24, step=1, value=0,
                  info="Pitch shift in semitones, only works when F0 conditioned model is used / 半音数的音高变换,仅在勾选 '启用F0输入' 时生效"),
    ]

    examples = [
        ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, False, True, 0],
        ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 25, 1.0, 0.7, True, True, 0],
        ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
         "examples/reference/teio_0.wav", 100, 1.0, 0.7, True, False, 0],
        ["examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav",
         "examples/reference/trump_0.wav", 50, 1.0, 0.7, True, False, -12],
    ]

    outputs = [
        gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
        gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')
    ]

    return gr.Interface(
        fn=vc_wrapper.convert_voice,
        description=description,
        inputs=inputs,
        outputs=outputs,
        title="Seed Voice Conversion V1 (Voice & Singing Voice Conversion)",
        examples=examples,
        cache_examples=False,
    )

def create_v2_interface(vc_wrapper):
    # Set up Gradio interface
    description = ("Zero-shot voice/style conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
                   "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br> "
                   "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br> "
                   "Please click the 'convert style/emotion/accent' checkbox to convert the style, emotion, or accent of the source audio, or else only timbre conversion will be performed.<br> "
                   "Click the 'anonymization only' checkbox will ignore reference audio but convert source to an 'average voice' determined by model itself.<br> "
                   "无需训练的 zero-shot 语音/口音转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
                   "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。"
                   "<br>请勾选 'convert style/emotion/accent' 以转换源音频的风格、情感或口音,否则仅执行音色转换。<br>"
                   "勾选 'anonymization only' 会无视参考音频而将源音频转换为某种由模型自身决定的 '平均音色'。<br>"

                   "Credits to [Vevo](https://github.com/open-mmlab/Amphion/tree/main/models/vc/vevo)"
                   )
    inputs = [
        gr.Audio(type="filepath", label="Source Audio / 源音频"),
        gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
        gr.Slider(minimum=1, maximum=200, value=30, step=1, label="Diffusion Steps / 扩散步数",
                  info="30 by default, 50~100 for best quality / 默认为 30,50~100 为最佳质量"),
        gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整",
                  info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="Intelligibility CFG Rate",
                  info="controls pronunciation intelligibility / 控制发音清晰度"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Similarity CFG Rate",
                  info="controls similarity to reference audio / 控制与参考音频的相似度"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.9, label="Top-p",
                  info="AR model sampling top P"),
        gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature",
                  info="AR model sampling temperature"),
        gr.Slider(minimum=1.0, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty",
                  info="AR model sampling repetition penalty"),
        gr.Checkbox(label="convert style/emotion/accent", value=False),
        gr.Checkbox(label="anonymization only", value=False),
    ]

    examples = [
        ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 50, 1.0, 0.0, 0.7, 0.9, 1.0, 1.0, False, False],
        ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 50, 1.0, 0.0, 0.7, 0.9, 1.0, 1.0, False, False],
    ]

    outputs = [
        gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
        gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')
    ]

    return gr.Interface(
        fn=vc_wrapper.convert_voice_with_streaming,
        description=description,
        inputs=inputs,
        outputs=outputs,
        title="Seed Voice Conversion V2 (Voice & Style Conversion)",
        examples=examples,
        cache_examples=False,
    )

def main(args):
    # Load V2 models
    vc_wrapper_v2 = load_v2_models(args)

    # Create interfaces
    v1_interface = create_v1_interface()
    v2_interface = create_v2_interface(vc_wrapper_v2)

    # Create tabs
    with gr.Blocks(title="Seed Voice Conversion") as demo:
        gr.Markdown("# Seed Voice Conversion")
        gr.Markdown("Choose between V1 (Voice & Singing Voice Conversion) or V2 (Voice & Style Conversion)")

        with gr.Tabs():
            with gr.TabItem("V2 - Voice & Style Conversion"):
                v2_interface.render()
            with gr.TabItem("V1 - Voice & Singing Voice Conversion"):
                v1_interface.render()

    # Launch the combined interface
    demo.launch()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--compile", type=bool, default=True)
    args = parser.parse_args()
    main(args)

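Since the V2 tab binds fn=vc_wrapper.convert_voice_with_streaming to the eleven inputs above, the same function could in principle be driven outside Gradio. The sketch below is a hedged illustration only: the positional argument order is taken from the inputs list in create_v2_interface(), but the parameter names, the generator behaviour, and the shape of the yielded audio are assumptions not confirmed by this diff.

# Hedged sketch: call the V2 wrapper directly, passing arguments in the same
# positional order as the Gradio inputs list in create_v2_interface().
# Everything about the yield/return format here is an assumption.
vc_wrapper = load_v2_models(args)  # as in main()

stream = vc_wrapper.convert_voice_with_streaming(
    "examples/source/yae_0.wav",          # source audio
    "examples/reference/dingzhen_0.wav",  # reference audio
    30,     # diffusion steps
    1.0,    # length adjust
    0.0,    # intelligibility CFG rate
    0.7,    # similarity CFG rate
    0.9,    # top-p
    1.0,    # temperature
    1.0,    # repetition penalty
    False,  # convert style/emotion/accent
    False,  # anonymization only
)
for chunk in stream:  # assumed to be a generator feeding the streaming output
    pass              # the final value would correspond to the full output audio
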
configs/astral_quantization/default_2048.yml
ADDED
@@ -0,0 +1,40 @@
_target_: modules.astral_quantization.default_model.AstralQuantizer
tokenizer_name: "openai/whisper-small"
ssl_model_name: "facebook/hubert-large-ll60k"
ssl_output_layer: 18
encoder:
  _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
  dim: 512
  num_blocks: 12
  intermediate_dim: 1536
  dilation: 1
  input_dim: 1024
quantizer:
  _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
  codebook_size: 2048 # codebook size, must be a power of 2
  dim: 512
  entropy_loss_weight: 0.1
  diversity_gamma: 1.0
  spherical: True
  enable_entropy_loss: True
  soft_entropy_loss: True
decoder:
  _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
  dim: 512
  num_blocks: 12
  intermediate_dim: 1536
  dilation: 1
  output_dim: 1024
  gin_channels: 192
asr_decoder:
  _target_: modules.astral_quantization.asr_decoder.ASRDecoder
  hidden_dim: 768
  num_heads: 12
  depth: 12
  block_size: 4096
  in_channels: 512
  n_vocab: 51866
  bos_id: 50528
  eos_id: 50527
  dropout_rate: 0.0
  attn_dropout_rate: 0.0

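These astral_quantization configs follow the Hydra `_target_` convention, and app_v1v2.py loads configs/v2/vc_wrapper.yaml the same way. A minimal sketch of how a config like this could be instantiated, assuming hydra-core and omegaconf are installed and the referenced modules are importable from the repository root:

# Minimal sketch: instantiate a Hydra-style config such as
# configs/astral_quantization/default_2048.yml. This mirrors what
# app_v1v2.py does for configs/v2/vc_wrapper.yaml; the config path and the
# availability of the target modules are assumptions.
import yaml
from omegaconf import DictConfig
from hydra.utils import instantiate

with open("configs/astral_quantization/default_2048.yml", "r") as f:
    cfg = DictConfig(yaml.safe_load(f))

# Recursively builds AstralQuantizer with its encoder/quantizer/decoder
# sub-modules as declared by the _target_ entries.
quantizer_model = instantiate(cfg)
print(type(quantizer_model))
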
configs/astral_quantization/default_32.yml
ADDED
@@ -0,0 +1,40 @@
_target_: default_model.AstralQuantizer
tokenizer_name: "openai/whisper-small"
ssl_model_name: "facebook/hubert-large-ll60k"
ssl_output_layer: 18
encoder:
  _target_: modules.convnext.ConvNeXtV2Stage
  dim: 512
  num_blocks: 12
  intermediate_dim: 1536
  dilation: 1
  input_dim: 1024
quantizer:
  _target_: modules.bsq.BinarySphericalQuantize
  codebook_size: 32 # codebook size, must be a power of 2
  dim: 512
  entropy_loss_weight: 0.1
  diversity_gamma: 1.0
  spherical: True
  enable_entropy_loss: True
  soft_entropy_loss: True
decoder:
  _target_: modules.convnext.ConvNeXtV2Stage
  dim: 512
  num_blocks: 12
  intermediate_dim: 1536
  dilation: 1
  output_dim: 1024
  gin_channels: 192
asr_decoder:
  _target_: modules.asr_decoder.ASRDecoder
  hidden_dim: 768
  num_heads: 12
  depth: 12
  block_size: 4096
  in_channels: 512
  n_vocab: 51866
  bos_id: 50528
  eos_id: 50527
  dropout_rate: 0.0
  attn_dropout_rate: 0.0

configs/config.json
ADDED
@@ -0,0 +1 @@
{"reference_audio_path": "D:/FAcodec/test_waves/kobe_0.wav", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "\u9ea6\u514b\u98ce (Razer BlackShark V2 HS 2.4", "sg_output_device": "\u626c\u58f0\u5668 (Razer BlackShark V2 HS 2.4", "sr_type": "sr_model", "diffusion_steps": 10.0, "inference_cfg_rate": 0.0, "max_prompt_length": 3.0, "block_time": 0.7, "crossfade_length": 0.04, "extra_time": 0.5, "extra_time_right": 0.02}

configs/inuse/.gitignore
ADDED
File without changes

configs/inuse/config.json
ADDED
@@ -0,0 +1 @@
{"reference_audio_path": "D:/seed-vc/examples/reference/trump_0.wav", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "\u9ea6\u514b\u98ce (Razer BlackShark V2 HS USB", "sg_output_device": "\u626c\u58f0\u5668 (Razer BlackShark V2 HS USB", "sr_type": "sr_model", "diffusion_steps": 8.0, "inference_cfg_rate": 0.7, "max_prompt_length": 3.0, "block_time": 0.58, "crossfade_length": 0.04, "extra_time_ce": 2.5, "extra_time": 0.5, "extra_time_right": 0.02}

configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml
ADDED
@@ -0,0 +1,98 @@
log_dir: "./runs"
save_freq: 1
log_interval: 10
save_interval: 1000
device: "cuda"
epochs: 1000 # number of epochs for first stage training (pre-training)
batch_size: 1
batch_length: 100 # maximum duration of audio in a batch (in seconds)
max_len: 80 # maximum number of frames
pretrained_model: "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth"
pretrained_encoder: ""
load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters

preprocess_params:
  sr: 44100
  spect_params:
    n_fft: 2048
    win_length: 2048
    hop_length: 512
    n_mels: 128
    fmin: 0
    fmax: "None"

model_params:
  dit_type: "DiT" # uDiT or DiT
  reg_loss_type: "l1" # l1 or l2

  timbre_shifter:
    se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
    ckpt_path: './modules/openvoice/checkpoints_v2/converter'

  vocoder:
    type: "bigvgan"
    name: "nvidia/bigvgan_v2_44khz_128band_512x"

  speech_tokenizer:
    type: 'whisper'
    name: "openai/whisper-small"

  style_encoder:
    dim: 192
    campplus_path: "campplus_cn_common.bin"

  DAC:
    encoder_dim: 64
    encoder_rates: [2, 5, 5, 6]
    decoder_dim: 1536
    decoder_rates: [ 6, 5, 5, 2 ]
    sr: 24000

  length_regulator:
    channels: 768
    is_discrete: false
    in_channels: 768
    content_codebook_size: 2048
    sampling_ratios: [1, 1, 1, 1]
    vector_quantize: false
    n_codebooks: 1
    quantizer_dropout: 0.0
    f0_condition: true
    n_f0_bins: 256

  DiT:
    hidden_dim: 768
    num_heads: 12
    depth: 17
    class_dropout_prob: 0.1
    block_size: 8192
    in_channels: 128
    style_condition: true
    final_layer_type: 'mlp'
    target: 'mel' # mel or codec
    content_dim: 768
    content_codebook_size: 1024
    content_type: 'discrete'
    f0_condition: true
    n_f0_bins: 256
    content_codebooks: 1
    is_causal: false
    long_skip_connection: false
    zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
    time_as_token: false
    style_as_token: false
    uvit_skip_connection: true
    add_resblock_in_transformer: false

  wavenet:
    hidden_dim: 768
    num_layers: 8
    kernel_size: 5
    dilation_rate: 1
    p_dropout: 0.2
    style_condition: true

loss_params:
  base_lr: 0.0001
  lambda_mel: 45
  lambda_kl: 1.0

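For orientation, the mel frame rate implied by these presets follows directly from sr and hop_length. The check below is plain arithmetic on the values shown in this preset and in the 22.05 kHz presets that follow (nothing repo-specific is assumed):

# Arithmetic check on the preset spectrogram settings (values copied from the configs).
sr_44k, hop_44k = 44100, 512
sr_22k, hop_22k = 22050, 256
print(sr_44k / hop_44k)  # ~86.1 mel frames per second for the 44.1 kHz f0 preset above
print(sr_22k / hop_22k)  # ~86.1 mel frames per second for the 22.05 kHz presets below

Both sample-rate/hop-length pairs therefore produce roughly 86 mel frames per second.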
configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml
ADDED
@@ -0,0 +1,91 @@
log_dir: "./runs"
save_freq: 1
log_interval: 10
save_interval: 1000
device: "cuda"
epochs: 1000 # number of epochs for first stage training (pre-training)
batch_size: 2
batch_length: 100 # maximum duration of audio in a batch (in seconds)
max_len: 80 # maximum number of frames
pretrained_model: "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth"
pretrained_encoder: ""
load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters

preprocess_params:
  sr: 22050
  spect_params:
    n_fft: 1024
    win_length: 1024
    hop_length: 256
    n_mels: 80
    fmin: 0
    fmax: "None"

model_params:
  dit_type: "DiT" # uDiT or DiT
  reg_loss_type: "l1" # l1 or l2

  timbre_shifter:
    se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
    ckpt_path: './modules/openvoice/checkpoints_v2/converter'

  speech_tokenizer:
    type: 'whisper'
    name: "openai/whisper-small"

  style_encoder:
    dim: 192
    campplus_path: "campplus_cn_common.bin"

  vocoder:
    type: "bigvgan"
    name: "nvidia/bigvgan_v2_22khz_80band_256x"

  length_regulator:
    channels: 512
    is_discrete: false
    in_channels: 768
    content_codebook_size: 2048
    sampling_ratios: [1, 1, 1, 1]
    vector_quantize: false
    n_codebooks: 1
    quantizer_dropout: 0.0
    f0_condition: false
    n_f0_bins: 512

  DiT:
    hidden_dim: 512
    num_heads: 8
    depth: 13
    class_dropout_prob: 0.1
    block_size: 8192
    in_channels: 80
    style_condition: true
    final_layer_type: 'wavenet'
    target: 'mel' # mel or codec
    content_dim: 512
    content_codebook_size: 1024
    content_type: 'discrete'
    f0_condition: false
    n_f0_bins: 512
    content_codebooks: 1
    is_causal: false
    long_skip_connection: true
    zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
    time_as_token: false
    style_as_token: false
    uvit_skip_connection: true
    add_resblock_in_transformer: false

  wavenet:
    hidden_dim: 512
    num_layers: 8
    kernel_size: 5
    dilation_rate: 1
    p_dropout: 0.2
    style_condition: true

loss_params:
  base_lr: 0.0001
  lambda_mel: 45
  lambda_kl: 1.0

configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml
ADDED
@@ -0,0 +1,82 @@
log_dir: "./runs/"
save_freq: 1
log_interval: 10
save_interval: 500
device: "cuda"
epochs: 1000 # number of epochs for first stage training (pre-training)
batch_size: 2
batch_length: 100 # maximum duration of audio in a batch (in seconds)
max_len: 80 # maximum number of frames
pretrained_model: "DiT_uvit_tat_xlsr_ema.pth"
pretrained_encoder: ""
load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters

preprocess_params:
  sr: 22050
  spect_params:
    n_fft: 1024
    win_length: 1024
    hop_length: 256
    n_mels: 80
    fmin: 0
    fmax: 8000

model_params:
  dit_type: "DiT" # uDiT or DiT
  reg_loss_type: "l1" # l1 or l2
  diffusion_type: "flow"

  timbre_shifter:
    se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
    ckpt_path: './modules/openvoice/checkpoints_v2/converter'

  vocoder:
    type: "hifigan"

  speech_tokenizer:
    type: 'xlsr'
    output_layer: 12
    name: 'facebook/wav2vec2-xls-r-300m'

  style_encoder:
    dim: 192
    campplus_path: "campplus_cn_common.bin"

  length_regulator:
    channels: 384
    is_discrete: false
    in_channels: 1024
    content_codebook_size: 1024
    sampling_ratios: [1, 1, 1, 1]
    vector_quantize: false
    n_codebooks: 2
    quantizer_dropout: 0.0
    f0_condition: false
    n_f0_bins: 512

  DiT:
    hidden_dim: 384
    num_heads: 6
    depth: 9
    class_dropout_prob: 0.1
    block_size: 8192
    in_channels: 80
    style_condition: true
    final_layer_type: 'mlp'
    target: 'mel' # mel or betavae
    content_dim: 384
    content_codebook_size: 1024
    content_type: 'discrete'
    f0_condition: false
    n_f0_bins: 512
    content_codebooks: 1
    is_causal: false
    long_skip_connection: false
    zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
    time_as_token: true
    style_as_token: true
    uvit_skip_connection: true
    add_resblock_in_transformer: false

loss_params:
  base_lr: 0.0001

configs/v2/ar_base.yaml
ADDED
File without changes

configs/v2/dit_small.yaml
ADDED
@@ -0,0 +1,17 @@
_target_: modules.v2.cfm.CFM
estimator:
  _target_: modules.v2.dit_wrapper.DiT
  time_as_token: true
  style_as_token: true
  uvit_skip_connection: false
  block_size: 8192
  depth: 13
  num_heads: 8
  hidden_dim: 512
  in_channels: 80
  content_dim: 512
  style_encoder_dim: 192
  class_dropout_prob: 0.1
  dropout_rate: 0.0
  attn_dropout_rate: 0.0

configs/v2/vc_wrapper.yaml
ADDED
@@ -0,0 +1,105 @@
_target_: modules.v2.vc_wrapper.VoiceConversionWrapper
sr: 22050
hop_size: 256
mel_fn:
  _target_: modules.audio.mel_spectrogram
  _partial_: true
  n_fft: 1024
  win_size: 1024
  hop_size: 256
  num_mels: 80
  sampling_rate: 22050
  fmin: 0
  fmax: null
  center: False
cfm:
  _target_: modules.v2.cfm.CFM
  estimator:
    _target_: modules.v2.dit_wrapper.DiT
    time_as_token: true
    style_as_token: true
    uvit_skip_connection: false
    block_size: 8192
    depth: 13
    num_heads: 8
    hidden_dim: 512
    in_channels: 80
    content_dim: 512
    style_encoder_dim: 192
    class_dropout_prob: 0.1
    dropout_rate: 0.0
    attn_dropout_rate: 0.0
cfm_length_regulator:
  _target_: modules.v2.length_regulator.InterpolateRegulator
  channels: 512
  is_discrete: true
  codebook_size: 2048
  sampling_ratios: [ 1, 1, 1, 1 ]
  f0_condition: false
ar:
  _target_: modules.v2.ar.NaiveWrapper
  model:
    _target_: modules.v2.ar.NaiveTransformer
    config:
      _target_: modules.v2.ar.NaiveModelArgs
      dropout: 0.0
      rope_base: 10000.0
      dim: 768
      head_dim: 64
      n_local_heads: 2
      intermediate_size: 2304
      n_head: 12
      n_layer: 12
      vocab_size: 2049 # 1 + 1 for eos
ar_length_regulator:
  _target_: modules.v2.length_regulator.InterpolateRegulator
  channels: 768
  is_discrete: true
  codebook_size: 32
  sampling_ratios: [ ]
  f0_condition: false
style_encoder:
  _target_: modules.campplus.DTDNN.CAMPPlus
  feat_dim: 80
  embedding_size: 192
content_extractor_narrow:
  _target_: modules.astral_quantization.default_model.AstralQuantizer
  tokenizer_name: "openai/whisper-small"
  ssl_model_name: "facebook/hubert-large-ll60k"
  ssl_output_layer: 18
  skip_ssl: true
  encoder: &bottleneck_encoder
    _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
    dim: 512
    num_blocks: 12
    intermediate_dim: 1536
    dilation: 1
    input_dim: 1024
  quantizer:
    _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
    codebook_size: 32 # codebook size, must be a power of 2
    dim: 512
    entropy_loss_weight: 0.1
    diversity_gamma: 1.0
    spherical: True
    enable_entropy_loss: True
    soft_entropy_loss: True
content_extractor_wide:
  _target_: modules.astral_quantization.default_model.AstralQuantizer
  tokenizer_name: "openai/whisper-small"
  ssl_model_name: "facebook/hubert-large-ll60k"
  ssl_output_layer: 18
  encoder: *bottleneck_encoder
  quantizer:
    _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
    codebook_size: 2048 # codebook size, must be a power of 2
    dim: 512
    entropy_loss_weight: 0.1
    diversity_gamma: 1.0
    spherical: True
    enable_entropy_loss: True
    soft_entropy_loss: True
vocoder:
  _target_: modules.bigvgan.bigvgan.BigVGAN.from_pretrained
  pretrained_model_name_or_path: "nvidia/bigvgan_v2_22khz_80band_256x"
  use_cuda_kernel: false

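One detail worth noting in vc_wrapper.yaml: the &bottleneck_encoder anchor under content_extractor_narrow is reused via *bottleneck_encoder for content_extractor_wide, so both extractors are declared with identical ConvNeXtV2Stage settings. A small, self-contained sketch of how a YAML loader resolves such an anchor/alias pair (the keys below are illustrative, not the full config):

# Illustrative only: shows how a YAML anchor/alias pair resolves to the same
# mapping when loaded, as used by &bottleneck_encoder / *bottleneck_encoder above.
import yaml

doc = """
narrow:
  encoder: &bottleneck_encoder
    dim: 512
    num_blocks: 12
wide:
  encoder: *bottleneck_encoder
"""
cfg = yaml.safe_load(doc)
assert cfg["narrow"]["encoder"] == cfg["wide"]["encoder"]
print(cfg["wide"]["encoder"])  # {'dim': 512, 'num_blocks': 12}
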
hf_utils.py
CHANGED
@@ -2,7 +2,7 @@ import os
from huggingface_hub import hf_hub_download


-def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename=
+def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename=None):
    os.makedirs("./checkpoints", exist_ok=True)
    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints")
    if config_filename is None:

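For reference, a hedged usage sketch of the completed helper; the repo id below is a placeholder, not a value taken from this commit, and the exact return shape is not shown in the excerpt above.

# Hypothetical usage of hf_utils.load_custom_model_from_hf; the repo id is an
# illustrative placeholder. The helper downloads into ./checkpoints via
# huggingface_hub.hf_hub_download.
from hf_utils import load_custom_model_from_hf

result = load_custom_model_from_hf(
    repo_id="some-user/some-seed-vc-checkpoint",  # placeholder repo id
    model_filename="pytorch_model.bin",
)
# The diff excerpt cuts off after the config_filename check, so the exact
# return value is not shown here; the downloaded file lands in ./checkpoints.
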
modules/__pycache__/audio.cpython-310.pyc
CHANGED
Binary files a/modules/__pycache__/audio.cpython-310.pyc and b/modules/__pycache__/audio.cpython-310.pyc differ

modules/__pycache__/commons.cpython-310.pyc
CHANGED
Binary files a/modules/__pycache__/commons.cpython-310.pyc and b/modules/__pycache__/commons.cpython-310.pyc differ

modules/__pycache__/commons.cpython-38.pyc
ADDED
Binary file (14.2 kB)

modules/__pycache__/diffusion_transformer.cpython-310.pyc
CHANGED
Binary files a/modules/__pycache__/diffusion_transformer.cpython-310.pyc and b/modules/__pycache__/diffusion_transformer.cpython-310.pyc differ

modules/__pycache__/flow_matching.cpython-310.pyc
CHANGED
Binary files a/modules/__pycache__/flow_matching.cpython-310.pyc and b/modules/__pycache__/flow_matching.cpython-310.pyc differ

modules/__pycache__/length_regulator.cpython-310.pyc
CHANGED
Binary files a/modules/__pycache__/length_regulator.cpython-310.pyc and b/modules/__pycache__/length_regulator.cpython-310.pyc differ

modules/__pycache__/rmvpe.cpython-310.pyc
ADDED
Binary file (17.6 kB)

modules/astral_quantization/__pycache__/bsq.cpython-310.pyc
ADDED
Binary file (12.7 kB)

modules/astral_quantization/__pycache__/convnext.cpython-310.pyc
ADDED
Binary file (6.87 kB)

modules/astral_quantization/__pycache__/default_model.cpython-310.pyc
ADDED
Binary file (2.8 kB)

modules/astral_quantization/bsq.py
ADDED
@@ -0,0 +1,569 @@
"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737

In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.
"""

from math import log2, ceil
from functools import partial, cache
from collections import namedtuple
from contextlib import nullcontext

import torch.distributed as dist
from torch.distributed import nn as dist_nn

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module
from torch.amp import autocast

from einops import rearrange, reduce, pack, unpack

# constants

Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])

LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])

# distributed helpers

@cache
def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

def maybe_distributed_mean(t):
    if not is_distributed():
        return t

    dist_nn.all_reduce(t)
    t = t / dist.get_world_size()
    return t

# helper functions

def exists(v):
    return v is not None

def identity(t):
    return t

def default(*args):
    for arg in args:
        if exists(arg):
            return arg() if callable(arg) else arg
    return None

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def l2norm(t):
    return F.normalize(t, dim = -1)

# entropy

def log(t, eps = 1e-5):
    return t.clamp(min = eps).log()

def entropy(prob):
    return (-prob * log(prob)).sum(dim=-1)

# cosine sim linear

class CosineSimLinear(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        scale = 1.
    ):
        super().__init__()
        self.scale = scale
        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))

    def forward(self, x):
        x = F.normalize(x, dim = -1)
        w = F.normalize(self.weight, dim = 0)
        return (x @ w) * self.scale

def soft_entropy_loss(u, tau=1.0, gamma=1.0):
    """
    Compute the soft entropy loss for Binary Spherical Quantization (BSQ).

    Args:
        u (torch.Tensor): Input latent embeddings of shape (batch_size, L).
        tau (float): Temperature scaling factor.
        gamma (float): Weight for the second entropy term.

    Returns:
        torch.Tensor: Soft entropy loss.
    """
    # Binary quantization: Generate implicit codebook corners
    L = u.size(1)  # Dimensionality of codebook
    corners = torch.tensor([-1.0, 1.0], device=u.device) / (L**0.5)

    # Compute soft quantization probabilities for all dimensions
    # q_hat(c|u) for each dimension
    prob_matrix = torch.sigmoid(2 * tau * corners.unsqueeze(1) * u.unsqueeze(2))  # Shape: (batch_size, L, 2)

    # Entropy of q_hat(c|u) (independent along each dimension)
    entropy_per_dim = -torch.sum(prob_matrix * prob_matrix.log(), dim=-1)  # Shape: (batch_size, L)
    entropy_term1 = entropy_per_dim.mean()

    # Expected probabilities for dataset entropy (approximation)
    expected_probs = prob_matrix.mean(dim=0)  # Mean across batch, shape: (L, 2)
    entropy_term2 = -torch.sum(expected_probs * expected_probs.log(), dim=-1).mean()

    # Final entropy loss
    loss = entropy_term1 - gamma * entropy_term2
    return loss

# class

class BinarySphericalQuantize(Module):
    def __init__(
        self,
        *,
        dim = None,
        codebook_size = None,
        entropy_loss_weight = 0.1,
        commitment_loss_weight = 0.,
        diversity_gamma = 1.,
        straight_through_activation = nn.Identity(),
        num_codebooks = 1,
        keep_num_codebooks_dim = None,
        codebook_scale = 1.,  # for residual LFQ, codebook scaled down by 2x at each layer
        frac_per_sample_entropy = 0.25,  # make less than 1. to only use a random fraction of the probs for per sample entropy
        has_projections = None,
        projection_has_bias = True,
        soft_clamp_input_value = None,
        cosine_sim_project_in = False,
        cosine_sim_project_in_scale = None,
        channel_first = None,
        experimental_softplus_entropy_loss = False,
        entropy_loss_offset = 5.,  # how much to shift the loss before softplus
        spherical = True,  # from https://arxiv.org/abs/2406.07548
        force_quantization_f32 = True,  # will force the quantization step to be full precision
        enable_entropy_loss = True,
        soft_entropy_loss = True,
    ):
        super().__init__()

        # some assert validations

        assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
        assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'

        codebook_size = default(codebook_size, lambda: 2 ** dim)
        self.codebook_size = codebook_size

        codebook_dim = int(log2(codebook_size))
        codebook_dims = codebook_dim * num_codebooks
        dim = default(dim, codebook_dims)

        has_projections = default(has_projections, dim != codebook_dims)

        if cosine_sim_project_in:
            cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
            project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
        else:
            project_in_klass = partial(nn.Linear, bias = projection_has_bias)

        self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
        self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.dim = dim
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        # channel first

        self.channel_first = channel_first

        # straight through activation

        self.activation = straight_through_activation

        # whether to use BSQ (binary spherical quantization)

        self.spherical = spherical
        self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity

        # entropy aux loss related weights

        assert 0 < frac_per_sample_entropy <= 1.
        self.frac_per_sample_entropy = frac_per_sample_entropy

        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

        # codebook scale

        self.codebook_scale = codebook_scale

        # commitment loss

        self.commitment_loss_weight = commitment_loss_weight

        # whether to soft clamp the input value from -value to value

        self.soft_clamp_input_value = soft_clamp_input_value
        assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale

        # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)

        self.entropy_loss_offset = entropy_loss_offset
        self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss

        # for no auxiliary loss, during inference

        self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

        # whether to force quantization step to be f32

        self.force_quantization_f32 = force_quantization_f32

        # codes
        self.enable_entropy_loss = enable_entropy_loss
        self.soft_entropy_loss = soft_entropy_loss
        if codebook_size <= 100000:
            all_codes = torch.arange(codebook_size)
            bits = ((all_codes[..., None].int() & self.mask) != 0).float()
            codebook = self.bits_to_codes(bits)

            self.register_buffer('codebook', codebook.float(), persistent = False)
        else:
            all_codes = torch.arange(pow(2, 16))
            mask = 2 ** torch.arange(16 - 1, -1, -1)
            bits = ((all_codes[..., None].int() & mask) != 0).float()
            codebook = self.bits_to_codes(bits)

            self.register_buffer('codebook', codebook.float(), persistent = False)

    def bits_to_codes(self, bits):
        return bits * self.codebook_scale * 2 - self.codebook_scale

    @property
    def dtype(self):
        return self.codebook.dtype

    def indices_to_codes(
        self,
        indices,
        project_out = True
    ):
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
        should_transpose = default(self.channel_first, is_img_or_video)

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... -> ... 1')

        # indices to codes, which are bits of either -1 or 1

        bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)

        codes = self.bits_to_codes(bits)

        codes = self.maybe_l2norm(codes)

        codes = rearrange(codes, '... c d -> ... (c d)')

        # whether to project codes out to original dimensions
        # if the input feature dimensions were not log2(codebook size)

        if project_out:
            codes = self.project_out(codes)

        # rearrange codes back to original shape

        if should_transpose:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    def bits_to_z(self, bits):
        # assert bits must contain only -1 and 1
        assert torch.all(bits.abs() == 1)
        quantized = bits.float()
        quantized = self.maybe_l2norm(quantized)
        z = self.project_out(quantized)
        return z

    def forward(
        self,
        x,
        inv_temperature = 100.,
        return_loss_breakdown = False,
        mask = None,
        return_bits = False
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim
        """

        is_img_or_video = x.ndim >= 4
        should_transpose = default(self.channel_first, is_img_or_video)

        # standardize image or video into (batch, seq, dimension)

        if should_transpose:
            x = rearrange(x, 'b d ... -> b ... d')
            x, ps = pack_one(x, 'b * d')

        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

        x = self.project_in(x)

        # maybe soft clamp

        if exists(self.soft_clamp_input_value):
            clamp_value = self.soft_clamp_input_value
            x = (x / clamp_value).tanh() * clamp_value

        # split out number of codebooks

        x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
|
| 341 |
+
|
| 342 |
+
# maybe l2norm
|
| 343 |
+
|
| 344 |
+
x = self.maybe_l2norm(x)
|
| 345 |
+
|
| 346 |
+
# whether to force quantization step to be full precision or not
|
| 347 |
+
|
| 348 |
+
force_f32 = self.force_quantization_f32
|
| 349 |
+
|
| 350 |
+
quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
|
| 351 |
+
|
| 352 |
+
with quantization_context():
|
| 353 |
+
|
| 354 |
+
if force_f32:
|
| 355 |
+
orig_dtype = x.dtype
|
| 356 |
+
x = x.float()
|
| 357 |
+
|
| 358 |
+
# quantize by eq 3.
|
| 359 |
+
|
| 360 |
+
original_input = x
|
| 361 |
+
|
| 362 |
+
codebook_value = torch.ones_like(x) * self.codebook_scale
|
| 363 |
+
quantized = torch.where(x > 0, codebook_value, -codebook_value)
|
| 364 |
+
if return_bits:
|
| 365 |
+
return quantized
|
| 366 |
+
|
| 367 |
+
# calculate indices
|
| 368 |
+
|
| 369 |
+
indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
|
| 370 |
+
|
| 371 |
+
# maybe l2norm
|
| 372 |
+
|
| 373 |
+
quantized = self.maybe_l2norm(quantized)
|
| 374 |
+
|
| 375 |
+
# use straight-through gradients (optionally with custom activation fn) if training
|
| 376 |
+
|
| 377 |
+
if self.training:
|
| 378 |
+
x = self.activation(x)
|
| 379 |
+
x = x + (quantized - x).detach()
|
| 380 |
+
else:
|
| 381 |
+
x = quantized
|
| 382 |
+
|
| 383 |
+
# entropy aux loss
|
| 384 |
+
if self.soft_entropy_loss:
|
| 385 |
+
entropy_aux_loss = soft_entropy_loss(x, tau=1.0, gamma=1.0)
|
| 386 |
+
elif self.training and self.enable_entropy_loss:
|
| 387 |
+
|
| 388 |
+
if force_f32:
|
| 389 |
+
codebook = self.codebook.float()
|
| 390 |
+
|
| 391 |
+
codebook = self.maybe_l2norm(codebook)
|
| 392 |
+
|
| 393 |
+
# whether to only use a fraction of probs, for reducing memory
|
| 394 |
+
|
| 395 |
+
if self.frac_per_sample_entropy < 1.:
|
| 396 |
+
# account for mask
|
| 397 |
+
if exists(mask):
|
| 398 |
+
original_input = original_input[mask]
|
| 399 |
+
original_input = rearrange(original_input, 'b n ... -> (b n) ...')
|
| 400 |
+
|
| 401 |
+
rand_mask = torch.randn(self.codebook_dim).argsort(dim = -1) < 16
|
| 402 |
+
|
| 403 |
+
sampled_input = original_input[..., rand_mask]
|
| 404 |
+
|
| 405 |
+
sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
|
| 406 |
+
|
| 407 |
+
sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
|
| 408 |
+
|
| 409 |
+
per_sample_probs = sampled_prob
|
| 410 |
+
else:
|
| 411 |
+
if exists(mask):
|
| 412 |
+
original_input = original_input[mask]
|
| 413 |
+
original_input = rearrange(original_input, 'b n ... -> (b n) ...')
|
| 414 |
+
# the same as euclidean distance up to a constant
|
| 415 |
+
distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
|
| 416 |
+
|
| 417 |
+
prob = (-distance * inv_temperature).softmax(dim = -1)
|
| 418 |
+
|
| 419 |
+
per_sample_probs = prob
|
| 420 |
+
|
| 421 |
+
# calculate per sample entropy
|
| 422 |
+
|
| 423 |
+
per_sample_entropy = entropy(per_sample_probs).mean()
|
| 424 |
+
|
| 425 |
+
# distribution over all available tokens in the batch
|
| 426 |
+
|
| 427 |
+
avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')
|
| 428 |
+
|
| 429 |
+
avg_prob = maybe_distributed_mean(avg_prob)
|
| 430 |
+
|
| 431 |
+
codebook_entropy = entropy(avg_prob).mean()
|
| 432 |
+
|
| 433 |
+
# 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
|
| 434 |
+
# 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
|
| 435 |
+
|
| 436 |
+
entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
|
| 437 |
+
else:
|
| 438 |
+
# if not training, just return dummy 0
|
| 439 |
+
entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
|
| 440 |
+
|
| 441 |
+
# whether to make the entropy loss positive or not through a (shifted) softplus
|
| 442 |
+
|
| 443 |
+
if self.training and self.experimental_softplus_entropy_loss:
|
| 444 |
+
entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)
|
| 445 |
+
|
| 446 |
+
# commit loss
|
| 447 |
+
|
| 448 |
+
if self.training and self.commitment_loss_weight > 0.:
|
| 449 |
+
|
| 450 |
+
commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
|
| 451 |
+
|
| 452 |
+
if exists(mask):
|
| 453 |
+
commit_loss = commit_loss[mask]
|
| 454 |
+
|
| 455 |
+
commit_loss = commit_loss.mean()
|
| 456 |
+
else:
|
| 457 |
+
commit_loss = self.zero
|
| 458 |
+
|
| 459 |
+
# input back to original dtype if needed
|
| 460 |
+
|
| 461 |
+
if force_f32:
|
| 462 |
+
x = x.type(orig_dtype)
|
| 463 |
+
|
| 464 |
+
# merge back codebook dim
|
| 465 |
+
|
| 466 |
+
x = rearrange(x, 'b n c d -> b n (c d)')
|
| 467 |
+
|
| 468 |
+
# project out to feature dimension if needed
|
| 469 |
+
|
| 470 |
+
x = self.project_out(x)
|
| 471 |
+
|
| 472 |
+
# reconstitute image or video dimensions
|
| 473 |
+
|
| 474 |
+
if should_transpose:
|
| 475 |
+
x = unpack_one(x, ps, 'b * d')
|
| 476 |
+
x = rearrange(x, 'b ... d -> b d ...')
|
| 477 |
+
|
| 478 |
+
indices = unpack_one(indices, ps, 'b * c')
|
| 479 |
+
|
| 480 |
+
# whether to remove single codebook dim
|
| 481 |
+
|
| 482 |
+
if not self.keep_num_codebooks_dim:
|
| 483 |
+
indices = rearrange(indices, '... 1 -> ...')
|
| 484 |
+
|
| 485 |
+
# complete aux loss
|
| 486 |
+
|
| 487 |
+
aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
|
| 488 |
+
|
| 489 |
+
# returns
|
| 490 |
+
|
| 491 |
+
ret = Return(x, indices, aux_loss)
|
| 492 |
+
|
| 493 |
+
if not return_loss_breakdown:
|
| 494 |
+
return ret
|
| 495 |
+
|
| 496 |
+
return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
|
| 497 |
+
|
| 498 |
+
class GroupedResidualBSQ(Module):
|
| 499 |
+
def __init__(
|
| 500 |
+
self,
|
| 501 |
+
*,
|
| 502 |
+
dim,
|
| 503 |
+
groups = 1,
|
| 504 |
+
accept_image_fmap = False,
|
| 505 |
+
**kwargs
|
| 506 |
+
):
|
| 507 |
+
super().__init__()
|
| 508 |
+
self.dim = dim
|
| 509 |
+
self.groups = groups
|
| 510 |
+
assert (dim % groups) == 0
|
| 511 |
+
dim_per_group = dim // groups
|
| 512 |
+
|
| 513 |
+
self.accept_image_fmap = accept_image_fmap
|
| 514 |
+
|
| 515 |
+
self.rvqs = nn.ModuleList([])
|
| 516 |
+
|
| 517 |
+
for _ in range(groups):
|
| 518 |
+
self.rvqs.append(LFQ(
|
| 519 |
+
dim = dim_per_group,
|
| 520 |
+
**kwargs
|
| 521 |
+
))
|
| 522 |
+
|
| 523 |
+
self.codebook_size = self.rvqs[0].codebook_size
|
| 524 |
+
|
| 525 |
+
@property
|
| 526 |
+
def codebooks(self):
|
| 527 |
+
return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
|
| 528 |
+
|
| 529 |
+
@property
|
| 530 |
+
def split_dim(self):
|
| 531 |
+
return 1 if self.accept_image_fmap else -1
|
| 532 |
+
|
| 533 |
+
def get_codes_from_indices(self, indices):
|
| 534 |
+
codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
|
| 535 |
+
return torch.stack(codes)
|
| 536 |
+
|
| 537 |
+
def get_output_from_indices(self, indices):
|
| 538 |
+
outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
|
| 539 |
+
return torch.cat(outputs, dim = self.split_dim)
|
| 540 |
+
|
| 541 |
+
def forward(
|
| 542 |
+
self,
|
| 543 |
+
x,
|
| 544 |
+
return_all_codes = False
|
| 545 |
+
):
|
| 546 |
+
shape, split_dim = x.shape, self.split_dim
|
| 547 |
+
assert shape[split_dim] == self.dim
|
| 548 |
+
|
| 549 |
+
# split the feature dimension into groups
|
| 550 |
+
|
| 551 |
+
x = x.chunk(self.groups, dim = split_dim)
|
| 552 |
+
|
| 553 |
+
forward_kwargs = dict(
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
# invoke residual vq on each group
|
| 557 |
+
|
| 558 |
+
out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
|
| 559 |
+
out = tuple(zip(*out))
|
| 560 |
+
|
| 561 |
+
# otherwise, get all the zipped outputs and combine them
|
| 562 |
+
|
| 563 |
+
quantized, all_indices, *maybe_aux_loss = out
|
| 564 |
+
|
| 565 |
+
quantized = torch.cat(quantized, dim = split_dim)
|
| 566 |
+
all_indices = torch.stack(all_indices)
|
| 567 |
+
|
| 568 |
+
ret = (quantized, all_indices, *maybe_aux_loss)
|
| 569 |
+
return ret
|
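A minimal usage sketch of the quantizer above. It assumes the class is exported as LFQ from modules.astral_quantization.bsq (the GroupedResidualBSQ wrapper constructs LFQ directly); the sizes are illustrative, not the values used by this repository's configs.

# Hedged sketch; constructor argument names are taken from the __init__ body above.
import torch
from modules.astral_quantization.bsq import LFQ

quantizer = LFQ(
    codebook_size=2 ** 16,   # implies codebook_dim = log2(codebook_size) = 16 bits
    dim=512,                 # feature dim != 16, so project_in / project_out are active
    num_codebooks=1,
    spherical=True,          # BSQ behaviour: codes are l2-normalized and scaled
).eval()

x = torch.randn(2, 100, 512)                     # (batch, sequence, feature)
quantized, indices, aux_loss = quantizer(x)      # Return namedtuple unpacks to 3 fields
codes = quantizer.indices_to_codes(indices)      # decode indices back to feature space
print(quantized.shape, indices.shape, codes.shape)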
modules/astral_quantization/convnext.py
ADDED
@@ -0,0 +1,209 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List


class ConvNextV2LayerNorm(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            x = torch.nn.functional.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            input_dtype = x.dtype
            x = x.float()
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = x.to(dtype=input_dtype)
            x = self.weight[None, :, None] * x + self.bias[None, :, None]
        return x


class GRN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x

class InterpolationLayer(nn.Module):
    def __init__(self, ):  # this is a default of 1 / 50 * (44100 / 512) / 4
        super().__init__()
        pass

    def forward(self, x: torch.Tensor, target_len: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        x = F.interpolate(x, size=target_len, mode='linear')
        return x

class ConvNeXtV2Stage(nn.Module):
    def __init__(
        self,
        dim: int = 512,
        intermediate_dim: int = 2048,
        num_blocks: int = 1,
        dilation: int = 1,
        downsample_layer_indices: List[int] = None,
        downsample_factors: List[int] = None,
        upsample_layer_indices: List[int] = None,
        upsample_factors: List[int] = None,
        interpolation_layer_indices: List[int] = None,
        input_dim: int = None,
        output_dim: int = None,
        gin_channels: int = 0,
    ):
        super().__init__()
        # maybe downsample layers
        if downsample_layer_indices is not None:
            assert downsample_factors is not None
            self.downsample_blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
                        nn.Conv1d(
                            dim, dim, kernel_size=downsample_factor, stride=downsample_factor
                        ),
                    ) for _, downsample_factor in zip(downsample_layer_indices, downsample_factors)
                ]
            )
            self.downsample_layer_indices = downsample_layer_indices
        else:
            self.downsample_blocks = nn.ModuleList()
            self.downsample_layer_indices = []

        # maybe upsample layers
        if upsample_layer_indices is not None:
            assert upsample_factors is not None
            self.upsample_blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
                        nn.ConvTranspose1d(
                            dim, dim, kernel_size=upsample_factor, stride=upsample_factor
                        ),
                    ) for _, upsample_factor in zip(upsample_layer_indices, upsample_factors)
                ]
            )
            self.upsample_layer_indices = upsample_layer_indices
        else:
            self.upsample_blocks = nn.ModuleList()
            self.upsample_layer_indices = []

        # maybe interpolation layers
        if interpolation_layer_indices is not None:
            self.interpolation_blocks = nn.ModuleList(
                [
                    InterpolationLayer()
                    for _ in interpolation_layer_indices
                ]
            )
            self.interpolation_layer_indices = interpolation_layer_indices
        else:
            self.interpolation_blocks = nn.ModuleList()
            self.interpolation_layer_indices = []

        # main blocks
        self.blocks = nn.ModuleList(
            [
                ConvNeXtV2Block(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    dilation=dilation,
                )
                for _ in range(num_blocks)
            ]
        )
        # maybe input and output projections
        if input_dim is not None and input_dim != dim:
            self.input_projection = nn.Conv1d(input_dim, dim, kernel_size=1)
        else:
            self.input_projection = nn.Identity()
        if output_dim is not None and output_dim != dim:
            self.output_projection = nn.Conv1d(dim, output_dim, kernel_size=1)
        else:
            self.output_projection = nn.Identity()

        if gin_channels > 0:
            self.gin = nn.Conv1d(gin_channels, dim, kernel_size=1)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        x = self.input_projection(x)  # B, D, T
        if hasattr(self, 'gin'):
            g = kwargs['g']
            x = x + self.gin(g)
        # pad to a multiple of cumprod(downsample_factors)
        if len(self.downsample_blocks) > 0:
            downsample_factor = 1
            for factor in self.downsample_blocks:
                downsample_factor *= factor[1].stride[0]
            pad_len = downsample_factor - x.size(-1) % downsample_factor
            if pad_len > 0:
                x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1)

        # main blocks
        for layer_idx, block in enumerate(self.blocks):
            if layer_idx in self.downsample_layer_indices:
                x = self.downsample_blocks[self.downsample_layer_indices.index(layer_idx)](x)
            if layer_idx in self.upsample_layer_indices:
                x = self.upsample_blocks[self.upsample_layer_indices.index(layer_idx)](x)
            if layer_idx in self.interpolation_layer_indices:
                x = self.interpolation_blocks[self.interpolation_layer_indices.index(layer_idx)](x, target_len=kwargs['target_len'])
            x = block(x)
        x = self.output_projection(x)
        return x

    def setup_caches(self, *args, **kwargs):
        pass


class ConvNeXtV2Block(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = ConvNextV2LayerNorm(dim, data_format="channels_first")
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = x.transpose(1, 2)  # b d n -> b n d
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.transpose(1, 2)  # b n d -> b d n
        return residual + x
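A minimal sketch of running the 1D ConvNeXt-V2 stage above in isolation; the shapes and constructor values are illustrative, not the ones set in this repo's YAML configs.

# Hedged sketch; channels-first (B, C, T) tensors, as the forward above expects.
import torch
from modules.astral_quantization.convnext import ConvNeXtV2Stage

stage = ConvNeXtV2Stage(
    dim=512,
    intermediate_dim=2048,
    num_blocks=4,
    input_dim=768,    # projected to 512 by a 1x1 conv
    output_dim=512,
)

x = torch.randn(2, 768, 200)   # (batch, channels, time)
y = stage(x)
print(y.shape)                 # (2, 512, 200) when no down/up-sampling indices are configured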
modules/astral_quantization/default_model.py
ADDED
@@ -0,0 +1,73 @@
import torch
from transformers import AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor

class AstralQuantizer(torch.nn.Module):
    def __init__(
        self,
        tokenizer_name: str,
        ssl_model_name: str,
        ssl_output_layer: int,
        encoder: torch.nn.Module,
        quantizer: torch.nn.Module,
        skip_ssl: bool = False,
    ):
        super().__init__()
        self.encoder = encoder
        self.quantizer = quantizer
        self.tokenizer_name = tokenizer_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Load SSL model from Huggingface
        self.ssl_model_name = ssl_model_name
        self.ssl_output_layer = ssl_output_layer
        self.ssl_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ssl_model_name)

        if skip_ssl:  # in case the same SSL model has been loaded somewhere else
            self.ssl_model = None
        else:
            self.ssl_model = AutoModel.from_pretrained(ssl_model_name).eval()
            self.ssl_model.encoder.layers = self.ssl_model.encoder.layers[:ssl_output_layer]
            self.ssl_model.encoder.layer_norm = torch.nn.Identity()

    def load_separate_checkpoint(self, checkpoint_path):
        params = torch.load(checkpoint_path, map_location='cpu')['net']
        for key in params.keys():
            for k in list(params[key].keys()):
                if k.startswith("module."):
                    params[key][k[len("module."):]] = params[key][k]
                    del params[key][k]
        self.encoder.load_state_dict(params['encoder'])
        self.quantizer.load_state_dict(params['vq'])
        # note: `decoder` / `asr_decoder` are not created in __init__; these branches assume
        # they have been attached to the module externally before this method is called
        if self.decoder is not None:
            self.decoder.load_state_dict(params['decoder'])
        if self.asr_decoder is not None:
            self.asr_decoder.load_state_dict(params['predictor'], strict=False)

    def forward(self, waves_16k, wave_16k_lens, ssl_model=None):
        ssl_fn = self.ssl_model if self.ssl_model else ssl_model
        assert ssl_fn is not None, "In case in-class SSL model loading is skipped, external ssl_model must be provided"
        waves_16k_input_list = [
            waves_16k[bib, :wave_16k_lens[bib]].cpu().numpy()
            for bib in range(len(waves_16k))
        ]
        alt_inputs = self.ssl_feature_extractor(
            waves_16k_input_list,
            return_tensors='pt',
            return_attention_mask=True,
            padding=True,
            sampling_rate=16000
        ).to(waves_16k.device)
        feature_lens = alt_inputs.data['attention_mask'].sum(-1) // 320  # frame rate of hubert is 50 Hz

        outputs = ssl_fn(
            alt_inputs.input_values,
            attention_mask=alt_inputs.attention_mask,
        )
        last_hidden_states = outputs.last_hidden_state
        last_hidden_states = last_hidden_states[:, :feature_lens.max(), :]
        feature_lens = feature_lens.clamp(max=last_hidden_states.size(1))
        last_hidden_states = last_hidden_states.transpose(1, 2)
        x_hidden = self.encoder(last_hidden_states, feature_lens)
        x_hidden = x_hidden.transpose(1, 2)
        x_quantized, indices = self.quantizer(x_hidden)[:2]
        return x_quantized, indices, feature_lens
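A hedged sketch of the calling convention for AstralQuantizer.forward. The HF checkpoint names and the stand-in encoder/quantizer below are placeholders chosen only to satisfy the interfaces used in forward(); they are not the modules or checkpoints configured by this Space's YAML files.

import torch
import torch.nn as nn
from modules.astral_quantization.default_model import AstralQuantizer

class _Encoder(nn.Module):          # stand-in: (B, C_ssl, T), lens -> (B, D, T)
    def __init__(self, c_in=768, d=512):
        super().__init__()
        self.proj = nn.Conv1d(c_in, d, 1)
    def forward(self, x, lens):
        return self.proj(x)

class _Quantizer(nn.Module):        # stand-in: (B, T, D) -> (quantized, indices)
    def forward(self, x):
        return x, x.argmax(dim=-1)

aq = AstralQuantizer(
    tokenizer_name="gpt2",                        # placeholder tokenizer
    ssl_model_name="facebook/hubert-base-ls960",  # placeholder 16 kHz SSL model (hidden size 768)
    ssl_output_layer=9,
    encoder=_Encoder(),
    quantizer=_Quantizer(),
)

waves_16k = torch.randn(1, 16000 * 2)             # 2 s of padded 16 kHz audio
lens = torch.tensor([waves_16k.shape[-1]])
x_quantized, indices, feature_lens = aq(waves_16k, lens)  # roughly 50 frames per second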
modules/astral_quantization/transformer.py
ADDED
@@ -0,0 +1,254 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
import time

def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization"""

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        if embedding is None:
            return self.norm(input)
        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias


@dataclass
class ModelArgs:
    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    has_cross_attention: bool = False
    context_dim: int = 0
    is_causal: bool = False
    dropout_rate: float = 0.1
    attn_dropout_rate: float = 0.1

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        # self.head_dim = self.dim // self.n_head

class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
        self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))

        self.max_batch_size = -1
        self.max_seq_length = config.block_size
        freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
                                         self.config.rope_base)
        self.register_buffer("freqs_cis", freqs_cis)

        causal_mask = torch.tril(
            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
        )
        self.register_buffer("causal_mask", causal_mask)

    def forward(self,
                x: Tensor,
                c: Tensor,
                input_pos: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                context: Optional[Tensor] = None,
                context_input_pos: Optional[Tensor] = None,
                cross_attention_mask: Optional[Tensor] = None,
                ) -> Tensor:
        if mask is None:
            mask = self.causal_mask[:x.size(1), :x.size(1)]
        else:
            mask = mask[..., input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        if context is not None:
            context_freqs_cis = self.freqs_cis[context_input_pos]
        else:
            context_freqs_cis = None
        skip_in_x_list = []
        for i, layer in enumerate(self.layers):
            x = layer(x, c, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask)
        x = self.norm(x, c)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
        self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))

        if config.has_cross_attention:
            self.has_cross_attention = True
            self.cross_attention = Attention(config, is_cross_attention=True)
            self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
        else:
            self.has_cross_attention = False

    def forward(self,
                x: Tensor,
                c: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                cross_attention_mask: Optional[Tensor] = None,
                ) -> Tensor:
        #time_attn_start = time.time()
        h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask)
        #print(f"time take for attention of sequence length {x.shape[1]} is {time.time() - time_attn_start}")
        if self.has_cross_attention:
            h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, context, context_freqs_cis)
        out = h + self.feed_forward(self.ffn_norm(h, c))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        if is_cross_attention:
            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
            self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
        else:
            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.attn_dropout_rate = config.attn_dropout_rate

    def forward(self,
                x: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        if context is None:
            q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
            context_seqlen = seqlen
        else:
            q = self.wq(x)
            k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
            context_seqlen = context.shape[1]

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_dropout_rate if self.training else 0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)

        y = self.wo(y)
        return y


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000,
    dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
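A hedged sketch of driving the AdaLN transformer above: `c` is the conditioning embedding consumed by AdaptiveLayerNorm, `input_pos` selects rotary positions from the precomputed cache, and the lower-triangular mask is applied when no mask is passed. The sizes are illustrative, not those in configs/v2/*.yaml.

import torch
from modules.astral_quantization.transformer import ModelArgs, Transformer

args = ModelArgs(block_size=1024, n_layer=4, n_head=8, dim=512, head_dim=64)
model = Transformer(args)

b, t = 2, 128
x = torch.randn(b, t, args.dim)          # token features
c = torch.randn(b, 1, args.dim)          # conditioning vector broadcast by AdaptiveLayerNorm
input_pos = torch.arange(t)              # rotary position indices
out = model(x, c, input_pos=input_pos)   # (b, t, dim)
print(out.shape)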
modules/audio.py
CHANGED
@@ -1,82 +1,82 @@
(The before and after listings in this view are textually identical, so the change is most likely a whitespace or line-ending normalization only. The file reads as follows.)
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement
    if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(sampling_rate) + "_" + str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
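A hedged usage sketch of mel_spectrogram above. The STFT/mel values shown are illustrative for 44.1 kHz audio; the values actually used by this Space come from the preset YAMLs under configs/, not from this example.

import torch
from modules.audio import mel_spectrogram

wav = torch.randn(1, 44100)            # (batch, samples), expected roughly in [-1, 1]
mel = mel_spectrogram(
    wav, n_fft=2048, num_mels=128, sampling_rate=44100,
    hop_size=512, win_size=2048, fmin=0, fmax=None, center=False,
)
print(mel.shape)                       # (1, 128, frames), log-compressed mel magnitudes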
modules/bigvgan/__pycache__/activations.cpython-310.pyc                               ADDED    Binary file (4 kB)
modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc                                   ADDED    Binary file (11.8 kB)
modules/bigvgan/__pycache__/env.cpython-310.pyc                                       ADDED    Binary file (796 Bytes)
modules/bigvgan/__pycache__/meldataset.cpython-310.pyc                                ADDED    Binary file (8.54 kB)
modules/bigvgan/__pycache__/utils.cpython-310.pyc                                     ADDED    Binary file (2.84 kB)
modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc       ADDED    Binary file (158 Bytes)
modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc   ADDED    Binary file (2.34 kB)
modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc           ADDED    Binary file (1.99 kB)
modules/bigvgan/alias_free_activation/cuda/activation1d.py
CHANGED
@@ -3,10 +3,10 @@
 
 import torch
 import torch.nn as nn
-from …                                                  (original absolute import; truncated in this view)
+from ..torch.resample import UpSample1d, DownSample1d
 
 # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
-from …                                                  (original absolute import; truncated in this view)
+from ..cuda import load
 
 anti_alias_activation_cuda = load.load()
 
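The change above swaps absolute imports for package-relative ones, so the vendored BigVGAN code resolves through its parent packages instead of requiring modules/bigvgan on sys.path. A hedged illustration (import path assumed from this repo's layout); note the module-level load.load() call means importing it JIT-compiles the fused CUDA kernel, which needs nvcc and a matching toolchain (see the captured build.ninja below).

# Hedged illustration only; requires a working CUDA build environment.
from modules.bigvgan.alias_free_activation.cuda import activation1d  # noqa: F401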
modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e233713716a5778577f244b0f310944ff26d3079ce0e42491791da7d42e363c1
size 522068
modules/bigvgan/alias_free_activation/cuda/build/.ninja_log
ADDED
@@ -0,0 +1,7 @@
# ninja log v5
9 39554 7516864785377831 anti_alias_activation.o 3a177f31dd72c43c
13 152601 7516865914203767 anti_alias_activation_cuda.cuda.o 2d613e7382d803fd
152628 153062 7516865920541751 anti_alias_activation_cuda.pyd f6366e9bdfb27f7
128 50503 7654004565901584 anti_alias_activation.o 9ed3213f2e0d0858
133 176837 7654005827401976 anti_alias_activation_cuda.cuda.o a679b6661c609136
176839 177401 7654005835005523 anti_alias_activation_cuda.pyd f6366e9bdfb27f7
modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:74c2824b05582070b69f51ec588aadb268c4fddf18fbb4590f901d1cdf32185c
size 3246655
modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:86c48de557041de7ebaff7926b5f346cc5e4e2dddc6cf5b88409f6cb161db0f4
size 4724513
modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp    ADDED    Binary file (25.1 kB)
modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib    ADDED    Binary file (43.7 kB)
modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db37ea2dd31dfe67e68ee6019877d14638c41724ff9342c55f638f4d2cda3d03
size 2454528
modules/bigvgan/alias_free_activation/cuda/build/build.ninja
ADDED
@@ -0,0 +1,38 @@
ninja_required_version = 1.3
cxx = cl
nvcc = C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin\nvcc

cflags = -DTORCH_EXTENSION_NAME=anti_alias_activation_cuda -DTORCH_API_INCLUDE_EXTENSION_H -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\include" -ID:\Anaconda\envs\vocos\Include /std:c++17 -O3 /MD /wd4819 /wd4251 /wd4244 /wd4267 /wd4275 /wd4018 /wd4190 /wd4624 /wd4067 /wd4068 /EHsc
post_cflags =
cuda_cflags = -Xcudafe --diag_suppress=dll_interface_conflict_dllexport_assumed -Xcudafe --diag_suppress=dll_interface_conflict_none_assumed -Xcudafe --diag_suppress=field_without_dll_interface -Xcudafe --diag_suppress=base_class_has_different_dll_interface -Xcompiler /EHsc -Xcompiler /wd4068 -Xcompiler /wd4067 -Xcompiler /wd4624 -Xcompiler /wd4190 -Xcompiler /wd4018 -Xcompiler /wd4275 -Xcompiler /wd4267 -Xcompiler /wd4244 -Xcompiler /wd4251 -Xcompiler /wd4819 -Xcompiler /MD -DTORCH_EXTENSION_NAME=anti_alias_activation_cuda -DTORCH_API_INCLUDE_EXTENSION_H -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\include" -ID:\Anaconda\envs\vocos\Include -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++17 -O3 -gencode arch=compute_70,code=sm_70 --use_fast_math -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda -gencode arch=compute_80,code=sm_80
cuda_post_cflags =
cuda_dlink_post_cflags =
sycl_dlink_post_cflags =
ldflags = /DLL c10.lib c10_cuda.lib torch_cpu.lib torch_cuda.lib -INCLUDE:?warp_size@cuda@at@@YAHXZ torch.lib /LIBPATH:D:\Anaconda\envs\vocos\lib\site-packages\torch\lib torch_python.lib /LIBPATH:D:\Anaconda\envs\vocos\libs "/LIBPATH:C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\lib\x64" cudart.lib

rule compile
  command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags
  deps = msvc

rule cuda_compile
  depfile = $out.d
  deps = gcc
  command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags

rule link
  command = "D$:\Visual Studio\VC\Tools\MSVC\14.29.30133\bin\Hostx86\x64/link.exe" $in /nologo $ldflags /out:$out

build anti_alias_activation.o: compile D$:\seed-vc\modules\bigvgan\alias_free_activation\cuda\anti_alias_activation.cpp
build anti_alias_activation_cuda.cuda.o: cuda_compile D$:\seed-vc\modules\bigvgan\alias_free_activation\cuda\anti_alias_activation_cuda.cu

build anti_alias_activation_cuda.pyd: link anti_alias_activation.o anti_alias_activation_cuda.cuda.o

default anti_alias_activation_cuda.pyd
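This build.ninja and the .o/.pyd artifacts above look like the files torch's JIT extension loader writes into its build directory. A hedged sketch of the kind of call that produces them; the real invocation lives in alias_free_activation/cuda/load.py, and the name and source list here merely mirror the targets shown above.

from torch.utils.cpp_extension import load

anti_alias_activation_cuda = load(
    name="anti_alias_activation_cuda",
    sources=[
        "modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp",
        "modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu",
    ],
    extra_cuda_cflags=["-O3", "--use_fast_math"],  # illustrative flags
    verbose=True,
)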
modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc    ADDED    Binary file (217 Bytes)
modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc         ADDED    Binary file (1.05 kB)