Upload 201 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +12 -11
- .gitignore +1 -24
- .pre-commit-config.yaml +2 -2
- Data/Azusa/config.json +108 -0
- Data/Azusa/models/G_11300.pth +3 -0
- app.py +196 -177
- bert_gen.py +24 -17
- clap_gen.py +1 -1
- configs/config.json +771 -767
- data_utils.py +7 -23
- default_config.yml +3 -3
- export_onnx.py +6 -4
- for_deploy/infer.py +386 -0
- for_deploy/infer_utils.py +111 -0
- for_deploy/webui.py +556 -0
- infer.py +118 -88
- losses.py +95 -0
- models.py +66 -65
- oldVersion/V210/__init__.py +9 -4
- oldVersion/V210/models.py +1 -1
- oldVersion/V210/text/__init__.py +4 -2
- oldVersion/V210/text/chinese_bert.py +21 -2
- oldVersion/V210/text/english_bert_mock.py +21 -2
- oldVersion/V210/text/japanese_bert.py +23 -2
- onnx_infer.py +68 -0
- onnx_modules/V200/__init__.py +4 -0
- onnx_modules/V200_OnnxInference/__init__.py +126 -0
- onnx_modules/V210/__init__.py +4 -0
- onnx_modules/V210/models_onnx.py +1 -1
- onnx_modules/V210_OnnxInference/__init__.py +129 -0
- onnx_modules/V220/__init__.py +4 -0
- onnx_modules/V220/attentions_onnx.py +378 -0
- onnx_modules/V220/models_onnx.py +1076 -0
- onnx_modules/V220/text/__init__.py +1 -0
- onnx_modules/V220/text/symbols.py +187 -0
- onnx_modules/V220_OnnxInference/__init__.py +128 -0
- onnx_modules/V220_novq_dev/__init__.py +4 -0
- onnx_modules/V220_novq_dev/attentions_onnx.py +378 -0
- onnx_modules/V220_novq_dev/models_onnx.py +1048 -0
- onnx_modules/V220_novq_dev/text/__init__.py +1 -0
- onnx_modules/V220_novq_dev/text/symbols.py +187 -0
- onnx_modules/V230/__init__.py +4 -0
- onnx_modules/V230/attentions_onnx.py +378 -0
- onnx_modules/V230/models_onnx.py +1061 -0
- onnx_modules/V230/text/__init__.py +1 -0
- onnx_modules/V230/text/symbols.py +187 -0
- onnx_modules/V230_OnnxInference/__init__.py +126 -0
- onnx_modules/__init__.py +12 -4
- re_matching.py +0 -1
- requirements.txt +2 -3
.gitattributes
CHANGED
@@ -1,35 +1,36 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
+*.bin.* filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
-*.
-
+*.zstandard filter=lfs diff=lfs merge=lfs -text
+*.tfevents* filter=lfs diff=lfs merge=lfs -text
+*.db* filter=lfs diff=lfs merge=lfs -text
+*.ark* filter=lfs diff=lfs merge=lfs -text
+**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
+**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
+**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -textoldVersion/V200/text/cmudict_cache.pickle filter=lfs diff=lfs merge=lfs -text
+oldVersion/V210/text/cmudict_cache.pickle filter=lfs diff=lfs merge=lfs -text
+text/cmudict_cache.pickle filter=lfs diff=lfs merge=lfs -text
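These attribute changes decide which uploaded artifacts (model weights, ONNX exports, cached pickles) are stored as Git LFS objects instead of plain git blobs. As a rough, hedged illustration only (real gitattributes matching has additional rules this does not reproduce), a small Python check against a couple of paths from this commit:

from fnmatch import fnmatch
from pathlib import PurePosixPath

# Illustrative subset of the patterns tracked above (not the full list).
LFS_PATTERNS = ["*.pth", "*.ckpt", "*.safetensors", "*.tfevents*",
                "text/cmudict_cache.pickle"]

def roughly_lfs_tracked(path: str) -> bool:
    """Approximate gitattributes matching: patterns without a slash match the
    basename, patterns with a slash match the full (posix) path."""
    for pattern in LFS_PATTERNS:
        target = path if "/" in pattern else PurePosixPath(path).name
        if fnmatch(target, pattern):
            return True
    return False

print(roughly_lfs_tracked("Data/Azusa/models/G_11300.pth"))  # True
print(roughly_lfs_tracked("configs/config.json"))            # False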
.gitignore
CHANGED
@@ -159,27 +159,4 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 
-.DS_Store
-/models
-/logs
-
-filelists/*
-!/filelists/esd.list
-data/*
-/*.yml
-!/default_config.yml
-/Web/
-/emotional/*/*.bin
-/bert/*/*.bin
-/bert/*/*.h5
-/bert/*/*.model
-/bert/*/*.safetensors
-/bert/*/*.msgpack
-asr_transcript.py
-extract_list.py
-dataset
-/Data
-Model
-raw/
-logs/
-Data/*
+.DS_Store
.pre-commit-config.yaml
CHANGED
@@ -7,13 +7,13 @@ repos:
       - id: trailing-whitespace
 
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.1.
+    rev: v0.1.8
    hooks:
      - id: ruff
        args: [ --fix ]
 
   - repo: https://github.com/psf/black
-    rev: 23.
+    rev: 23.12.0
    hooks:
      - id: black
Data/Azusa/config.json
ADDED
@@ -0,0 +1,108 @@
{
  "train": {
    "log_interval": 100,
    "eval_interval": 100,
    "seed": 42,
    "epochs": 1000,
    "learning_rate": 0.0001,
    "betas": [0.8, 0.99],
    "eps": 1e-09,
    "batch_size": 12,
    "bf16_run": false,
    "lr_decay": 0.99995,
    "segment_size": 16384,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45,
    "c_kl": 1.0,
    "c_commit": 100,
    "skip_optimizer": true,
    "freeze_ZH_bert": false,
    "freeze_JP_bert": false,
    "freeze_EN_bert": false,
    "freeze_emo": false
  },
  "data": {
    "training_files": "Data/Azusa/filelists/train.list",
    "validation_files": "Data/Azusa/filelists/val.list",
    "max_wav_value": 32768.0,
    "sampling_rate": 44100,
    "filter_length": 2048,
    "hop_length": 512,
    "win_length": 2048,
    "n_mel_channels": 128,
    "mel_fmin": 0.0,
    "mel_fmax": null,
    "add_blank": true,
    "n_speakers": 1,
    "cleaned_text": true,
    "spk2id": {
      "Azusa": 0
    }
  },
  "model": {
    "use_spk_conditioned_encoder": true,
    "use_noise_scaled_mas": true,
    "use_mel_posterior_encoder": false,
    "use_duration_discriminator": true,
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "upsample_rates": [8, 8, 2, 2, 2],
    "upsample_initial_channel": 512,
    "upsample_kernel_sizes": [16, 16, 8, 2, 2],
    "n_layers_q": 3,
    "use_spectral_norm": false,
    "gin_channels": 512,
    "slm": {
      "model": "./slm/wavlm-base-plus",
      "sr": 16000,
      "hidden": 768,
      "nlayers": 13,
      "initial_channel": 64
    }
  },
  "version": "2.3"
}
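This is the per-character training/inference config for the Azusa voice in the Bert-VITS2 2.3 format. Elsewhere in this commit, bert_gen.py loads configs like this through utils.get_hparams_from_file; a minimal sketch of that usage, with the printed fields taken from the JSON above and attribute-style access assumed to be what the repo's HParams helper provides:

import utils  # repo-local helper, used the same way in bert_gen.py

# Assumes this is run from the repository root after this commit.
hps = utils.get_hparams_from_file("Data/Azusa/config.json")

print(hps.data.sampling_rate)  # 44100
print(hps.data.add_blank)      # True
print(hps.train.batch_size)    # 12
print(hps.data.spk2id)         # {'Azusa': 0}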
Data/Azusa/models/G_11300.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0708043b54ab21eb8ec1b600982ea7b105bcded370a9207281e043c64e195dc3
+size 728379830
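What is committed here is the Git LFS pointer, not the checkpoint itself; the roughly 728 MB generator weights are fetched from LFS storage on checkout. A small, hedged sketch of reading the pointer fields shown above (this helper is illustrative, not part of the repository):

def read_lfs_pointer(path: str) -> dict:
    """Parse a Git LFS pointer file ("key value" per line) into a dict."""
    fields = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    return fields

ptr = read_lfs_pointer("Data/Azusa/models/G_11300.pth")  # before the LFS object is pulled
print(ptr["oid"])        # sha256:0708043b54ab...
print(int(ptr["size"]))  # 728379830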
app.py
CHANGED

@@ -16,6 +16,10 @@ (new imports after import torch, plus an NLTK cmudict download with SSL verification disabled for the download):
 logger = logging.getLogger(__name__)
 
 import torch
+import ssl
+ssl._create_default_https_context = ssl._create_unverified_context
+import nltk
+nltk.download('cmudict')
 import utils
 from infer import infer, latest_version, get_net_g, infer_multilang
 import gradio as gr

@@ -42,6 +46,8 @@, @@ -49,8 +55,8 @@, @@ -66,10 +72,11 @@ (generate_audio gains style_text / style_weight, computes skip_start / skip_end per slice, forwards the style arguments to infer, and drops the commented-out silence append):
     language,
     reference_audio,
     emotion,
+    style_text,
+    style_weight,
     skip_start=False,
     skip_end=False,
 ):
     # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
     with torch.no_grad():
         for idx, piece in enumerate(slices):
+            skip_start = idx != 0
+            skip_end = idx != len(slices) - 1
             audio = infer(
                 piece,
                 reference_audio=reference_audio,
                 ...
                 device=device,
                 skip_start=skip_start,
                 skip_end=skip_end,
+                style_text=style_text,
+                style_weight=style_weight,
             )
             audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
             audio_list.append(audio16bit)
-            # audio_list.append(silence) # 将静音添加到列表中
     return audio_list

@@ -90,8 +97,8 @@, @@ -110,7 +117,6 @@ (the same skip_start / skip_end computation and silence-comment removal are applied to generate_audio_multilang).

@@ -127,63 +133,50 @@ (tts_split gains style_text / style_weight, drops the early "mix" guard, strips '|' from the input, filters empty paragraphs and sentences, and delegates synthesis to the new process_text helper):
     interval_between_sent,
     reference_audio,
     emotion,
+    style_text,
+    style_weight,
 ):
-    if language == "mix":
-        return ("invalid", None)
     while text.find("\n\n") != -1:
         text = text.replace("\n\n", "\n")
+    text = text.replace("|", "")
     para_list = re_matching.cut_para(text)
+    para_list = [p for p in para_list if p != ""]
     audio_list = []
+    for p in para_list:
+        if not cut_by_sent:
+            audio_list += process_text(
+                p, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale,
+                language, reference_audio, emotion, style_text, style_weight,
+            )
             silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
             audio_list.append(silence)
+        else:
             audio_list_sent = []
             sent_list = re_matching.cut_sent(p)
+            sent_list = [s for s in sent_list if s != ""]
+            for s in sent_list:
+                audio_list_sent += process_text(
+                    s, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale,
+                    language, reference_audio, emotion, style_text, style_weight,
+                )
                 silence = np.zeros((int)(44100 * interval_between_sent))
                 audio_list_sent.append(silence)
             if (interval_between_para - interval_between_sent) > 0:
                 ...
             )  # 对完整句子做音量归一
             audio_list.append(audio16bit)
     audio_concat = np.concatenate(audio_list)
+    return ("Success", (hps.data.sampling_rate, audio_concat))

@@ -196,10 +189,47 @@ through @@ -352,13 +296,65 @@ (the inline mix/auto handling that used to live in tts_fn is factored out into helpers, tts_fn is rewritten to call them, and format_utils is added; the removed old lines are truncated in this view, so only the new code is shown):
+def process_mix(slice):
+    _speaker = slice.pop()
+    _text, _lang = [], []
+    for lang, content in slice:
+        content = content.split("|")
+        content = [part for part in content if part != ""]
+        if len(content) == 0:
+            continue
+        if len(_text) == 0:
+            _text = [[part] for part in content]
+            _lang = [[lang] for part in content]
+        else:
+            _text[-1].append(content[0])
+            _lang[-1].append(lang)
+            if len(content) > 1:
+                _text += [[part] for part in content[1:]]
+                _lang += [[lang] for part in content[1:]]
+    return _text, _lang, _speaker
+
+
+def process_auto(text):
+    _text, _lang = [], []
+    for slice in text.split("|"):
+        if slice == "":
+            continue
+        temp_text, temp_lang = [], []
+        sentences_list = split_by_language(slice, target_languages=["zh", "ja", "en"])
+        for sentence, lang in sentences_list:
+            if sentence == "":
+                continue
+            temp_text.append(sentence)
+            temp_lang.append(lang.upper())
+        _text.append(temp_text)
+        _lang.append(temp_lang)
+    return _text, _lang
+
+
+def process_text(
+    text: str,
+    speaker,
+    sdp_ratio,
+    noise_scale,
+    noise_scale_w,
+    length_scale,
+    language,
+    reference_audio,
+    emotion,
+    style_text=None,
+    style_weight=0,
+):
+    audio_list = []
+    if language == "mix":
+        bool_valid, str_valid = re_matching.validate_text(text)
+        ...
+        for slice in re_matching.text_matching(text):
+            _text, _lang, _speaker = process_mix(slice)
+            if _speaker is None:
+                continue
+            print(f"Text: {_text}\nLang: {_lang}")
+            audio_list.extend(
+                generate_audio_multilang(
+                    _text, sdp_ratio, noise_scale, noise_scale_w, length_scale,
+                    _speaker, _lang, reference_audio, emotion,
+                )
+            )
+    elif language.lower() == "auto":
+        _text, _lang = process_auto(text)
+        print(f"Text: {_text}\nLang: {_lang}")
+        audio_list.extend(
+            generate_audio_multilang(
+                _text, sdp_ratio, noise_scale, noise_scale_w, length_scale,
+                speaker, _lang, reference_audio, emotion,
+            )
+        )
+    else:
+        audio_list.extend(
+            generate_audio(
+                ..., language, reference_audio, emotion, style_text, style_weight,
+            )
+        )
+    return audio_list
+
+
+def tts_fn(
+    text: str,
+    speaker,
+    sdp_ratio,
+    noise_scale,
+    noise_scale_w,
+    length_scale,
+    language,
+    reference_audio,
+    emotion,
+    prompt_mode,
+    style_text=None,
+    style_weight=0,
+):
+    if style_text == "":
+        style_text = None
+    if prompt_mode == "Audio prompt":
+        if reference_audio == None:
+            return ("Invalid audio prompt", None)
+        else:
+            reference_audio = load_audio(reference_audio)[1]
+    else:
+        reference_audio = None
+
+    audio_list = process_text(
+        text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale,
+        language, reference_audio, emotion, style_text, style_weight,
+    )
+
+    audio_concat = np.concatenate(audio_list)
+    return "Success", (hps.data.sampling_rate, audio_concat)
+
+
+def format_utils(text, speaker):
+    _text, _lang = process_auto(text)
+    res = f"[{speaker}]"
+    for lang_s, content_s in zip(_lang, _text):
+        for lang, content in zip(lang_s, content_s):
+            res += f"<{lang.lower()}>{content}"
+        res += "|"
+    return "mix", res[:-1]

@@ -394,10 +390,10 @@, @@ -414,27 +410,31 @@ (the Gradio header is rewritten for the Azusa voice, a "format as MIX" button is added, the prompt-mode widgets are hidden, and the SDP Ratio slider defaults to 0.5):
+                gr.Markdown(value="""
+                【AI阿梓】在线语音合成(Bert-Vits2 2.3中日英)\n
                 作者:Xz乔希 https://space.bilibili.com/5859321\n
+                声音归属:阿梓从小就很可爱 https://space.bilibili.com/7706705\n
                 【AI合集】https://www.modelscope.cn/studios/xzjosh/Bert-VITS2\n
                 Bert-VITS2项目:https://github.com/Stardust-minus/Bert-VITS2\n
                 使用本模型请严格遵守法律法规!\n
                 ...
                 另外,所有的语言选项都可以用'|'分割长段实现分句生成。
                 """,
                 )
+                formatter = gr.Button("检测语言,并整理为 MIX 格式", variant="primary")
                 speaker = gr.Dropdown(
                     choices=speakers, value=speakers[0], label="Speaker"
                 )
                 _ = gr.Markdown(
+                    value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n",
+                    visible=False,
                 )
                 prompt_mode = gr.Radio(
                     ["Text prompt", "Audio prompt"],
                     label="Prompt Mode",
                     value="Text prompt",
+                    visible=False,
                 )
                 text_prompt = gr.Textbox(
                     label="Text prompt",
+                    placeholder="用文字描述生成风格。如:Happy",
+                    value="Happy",
+                    visible=False,
                 )
                 audio_prompt = gr.Audio(
                     label="Audio prompt", type="filepath", visible=False
                 )
                 sdp_ratio = gr.Slider(
+                    minimum=0, maximum=1, value=0.5, step=0.01, label="SDP Ratio"
                 )
                 noise_scale = gr.Slider(
                     minimum=0.1, maximum=2, value=0.5, step=0.01, label="Noise"

@@ -450,6 +450,21 @@ (an experimental accordion adds the auxiliary style text box and style weight slider):
                 )
                 btn = gr.Button("点击生成", variant="primary")
             with gr.Column():
+                with gr.Accordion("融合文本语义(实验功能)", open=False):
+                    gr.Markdown(
+                        value="使用辅助文本的语意来辅助生成对话(语言保持与主文本相同)\n\n"
+                        "**注意**:请使用**带有强烈情感的文本**(如:我好快乐!)\n\n"
+                        "效果较不明确,留空即为不使用该功能"
+                    )
+                    style_text = gr.Textbox(label="辅助文本")
+                    style_weight = gr.Slider(
+                        minimum=0,
+                        maximum=1,
+                        value=0.7,
+                        step=0.1,
+                        label="Weight",
+                        info="主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本",
+                    )
                 with gr.Row():
                     with gr.Column():
                         interval_between_sent = gr.Slider(

@@ -492,6 +507,8 @@, @@ -510,6 +527,8 @@ (style_text and style_weight are appended to the input lists of both click handlers):
                 audio_prompt,
                 text_prompt,
                 prompt_mode,
+                style_text,
+                style_weight,
             ],
             outputs=[text_output, audio_output],
         )
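For orientation, a small, hedged sketch of calling the updated tts_fn outside Gradio. The values below are illustrative: noise_scale_w and length_scale defaults are not visible in this view, the speaker name comes from Data/Azusa/config.json, and the global net_g/hps are assumed to have been initialized by app.py's startup code:

# Hypothetical invocation; parameter names follow the new tts_fn() signature above.
status, (sr, wav) = tts_fn(
    text="你好,世界|Hello world",
    speaker="Azusa",
    sdp_ratio=0.5,
    noise_scale=0.5,
    noise_scale_w=0.9,   # assumed default, not shown in this diff view
    length_scale=1.0,    # assumed default, not shown in this diff view
    language="auto",
    reference_audio=None,
    emotion="Happy",
    prompt_mode="Text prompt",
    style_text="我好快乐!",  # auxiliary text from the style accordion, optional
    style_weight=0.7,
)
print(status, sr, wav.shape)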
bert_gen.py
CHANGED
@@ -1,17 +1,16 @@
-import argparse
-from multiprocessing import Pool, cpu_count
-
 import torch
-
-from tqdm import tqdm
-
+from multiprocessing import Pool
 import commons
 import utils
+from tqdm import tqdm
+from text import check_bert_models, cleaned_text_to_sequence, get_bert
+import argparse
+import torch.multiprocessing as mp
 from config import config
-from text import cleaned_text_to_sequence, get_bert
 
 
-def process_line(line):
+def process_line(x):
+    line, add_blank = x
     device = config.bert_gen_config.device
     if config.bert_gen_config.use_multi_device:
         rank = mp.current_process()._identity

@@ -28,12 +27,13 @@ def process_line(line):
     word2ph = [i for i in word2ph]
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
 
+    if add_blank:
+        phone = commons.intersperse(phone, 0)
+        tone = commons.intersperse(tone, 0)
+        language = commons.intersperse(language, 0)
+        for i in range(len(word2ph)):
+            word2ph[i] = word2ph[i] * 2
+        word2ph[0] += 1
 
     bert_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".bert.pt")

@@ -59,16 +59,23 @@
     args, _ = parser.parse_known_args()
     config_path = args.config
     hps = utils.get_hparams_from_file(config_path)
+    check_bert_models()
     lines = []
     with open(hps.data.training_files, encoding="utf-8") as f:
         lines.extend(f.readlines())
 
     with open(hps.data.validation_files, encoding="utf-8") as f:
         lines.extend(f.readlines())
+    add_blank = [hps.data.add_blank] * len(lines)
+
     if len(lines) != 0:
-        num_processes =
+        num_processes = args.num_processes
         with Pool(processes=num_processes) as pool:
-            for _ in tqdm(
+            for _ in tqdm(
+                pool.imap_unordered(process_line, zip(lines, add_blank)),
+                total=len(lines),
+            ):
+                # 这里是缩进的代码块,表示循环体
+                pass  # 使用pass语句作为占位符
 
     print(f"bert生成完毕!, 共有{len(lines)}个bert.pt生成!")
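For context on the add_blank branch above, a hedged sketch of what commons.intersperse is generally understood to do in VITS-style codebases: it inserts a blank token between and around every symbol, which is why word2ph is doubled afterwards. The implementation below is an assumption for illustration, not a copy of this repo's commons.py:

def intersperse(lst, item):
    # Returns [item, lst[0], item, lst[1], ..., item]; the blank token (0)
    # ends up between and around every phone/tone/language id.
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

print(intersperse([1, 2, 3], 0))  # [0, 1, 0, 2, 0, 3, 0]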
clap_gen.py
CHANGED
@@ -27,7 +27,7 @@ def process_line(line):
     device = torch.device("cpu")
     wav_path, _, language_str, text, phones, tone, word2ph = line.strip().split("|")
 
-    clap_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".emo.
+    clap_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".emo.pt")
     if os.path.isfile(clap_path):
         return
configs/config.json
CHANGED
@@ -10,18 +10,20 @@
     0.99
   ],
   "eps": 1e-09,
-  "batch_size":
-  "
+  "batch_size": 16,
+  "bf16_run": false,
   "lr_decay": 0.99995,
   "segment_size": 16384,
   "init_lr_ratio": 1,
   "warmup_epochs": 0,
   "c_mel": 45,
   "c_kl": 1.0,
+  "c_commit": 100,
   "skip_optimizer": true,
   "freeze_ZH_bert": false,
   "freeze_JP_bert": false,
-  "freeze_EN_bert": false
+  "freeze_EN_bert": false,
+  "freeze_emo": false
 },
 "data": {
   "training_files": "filelists/train.list",

@@ -35,7 +37,7 @@
   "mel_fmin": 0.0,
   "mel_fmax": null,
   "add_blank": true,
-  "n_speakers":
+  "n_speakers": 850,
   "cleaned_text": true,
   "spk2id": {
     "派蒙_ZH": 0,

@@ -119,203 +121,203 @@ and @@ -323,576 +325,571 @@
The spk2id table is rebuilt. The old entries past "绮良良_ZH": 80 are truncated in this view; the new side continues with "陌生人_ZH": 81, "七七_ZH": 82, "式大将_ZH": 83, … through the Chinese voices (up to "陆行岩本真蕈·元素生命_ZH": 195), then the Japanese voices ("派蒙_JP": 196, "纳西妲_JP": 197, … "塞萨尔的日记_JP": 427) and the English voices starting at "派蒙_EN": 428, "纳西妲_EN": 429, "凯亚_EN": 430, at which point this 50-file view is cut off. The old side also removed the trailing entries "女声_EN": 850, "陆景和": 851, "莫弈": 852, "左然": 853, "夏彦": 854.

@@ -947,7 +944,14 @@
   ],
   "n_layers_q": 3,
   "use_spectral_norm": false,
-  "gin_channels":
 },
-  "version": "2.
 }
The replacement values for gin_channels and version are not visible in this truncated view.
| 474 |
+
"阿贝多_EN": 431,
|
| 475 |
+
"温迪_EN": 432,
|
| 476 |
+
"枫原万叶_EN": 433,
|
| 477 |
+
"钟离_EN": 434,
|
| 478 |
+
"荒泷一斗_EN": 435,
|
| 479 |
+
"八重神子_EN": 436,
|
| 480 |
+
"艾尔海森_EN": 437,
|
| 481 |
+
"提纳里_EN": 438,
|
| 482 |
+
"迪希雅_EN": 439,
|
| 483 |
+
"卡维_EN": 440,
|
| 484 |
+
"宵宫_EN": 441,
|
| 485 |
+
"莱依拉_EN": 442,
|
| 486 |
+
"那维莱特_EN": 443,
|
| 487 |
+
"赛诺_EN": 444,
|
| 488 |
+
"莫娜_EN": 445,
|
| 489 |
+
"诺艾尔_EN": 446,
|
| 490 |
+
"托马_EN": 447,
|
| 491 |
+
"凝光_EN": 448,
|
| 492 |
+
"林尼_EN": 449,
|
| 493 |
+
"北斗_EN": 450,
|
| 494 |
+
"柯莱_EN": 451,
|
| 495 |
+
"神里绫华_EN": 452,
|
| 496 |
+
"可莉_EN": 453,
|
| 497 |
+
"芭芭拉_EN": 454,
|
| 498 |
+
"雷电将军_EN": 455,
|
| 499 |
+
"娜维娅_EN": 456,
|
| 500 |
+
"芙宁娜_EN": 457,
|
| 501 |
+
"珊瑚宫心海_EN": 458,
|
| 502 |
+
"鹿野院平藏_EN": 459,
|
| 503 |
+
"迪奥娜_EN": 460,
|
| 504 |
+
"五郎_EN": 461,
|
| 505 |
+
"琴_EN": 462,
|
| 506 |
+
"班尼特_EN": 463,
|
| 507 |
+
"达达利亚_EN": 464,
|
| 508 |
+
"安柏_EN": 465,
|
| 509 |
+
"莱欧斯利_EN": 466,
|
| 510 |
+
"夜兰_EN": 467,
|
| 511 |
+
"妮露_EN": 468,
|
| 512 |
+
"辛焱_EN": 469,
|
| 513 |
+
"珐露珊_EN": 470,
|
| 514 |
+
"丽莎_EN": 471,
|
| 515 |
+
"魈_EN": 472,
|
| 516 |
+
"香菱_EN": 473,
|
| 517 |
+
"迪卢克_EN": 474,
|
| 518 |
+
"砂糖_EN": 475,
|
| 519 |
+
"烟绯_EN": 476,
|
| 520 |
+
"早柚_EN": 477,
|
| 521 |
+
"云堇_EN": 478,
|
| 522 |
+
"刻晴_EN": 479,
|
| 523 |
+
"重云_EN": 480,
|
| 524 |
+
"优菈_EN": 481,
|
| 525 |
+
"胡桃_EN": 482,
|
| 526 |
+
"流浪者_EN": 483,
|
| 527 |
+
"久岐忍_EN": 484,
|
| 528 |
+
"神里绫人_EN": 485,
|
| 529 |
+
"甘雨_EN": 486,
|
| 530 |
+
"戴因斯雷布_EN": 487,
|
| 531 |
+
"菲谢尔_EN": 488,
|
| 532 |
+
"白术_EN": 489,
|
| 533 |
+
"行秋_EN": 490,
|
| 534 |
+
"九条裟罗_EN": 491,
|
| 535 |
+
"夏洛蒂_EN": 492,
|
| 536 |
+
"雷泽_EN": 493,
|
| 537 |
+
"申鹤_EN": 494,
|
| 538 |
+
"荧_EN": 495,
|
| 539 |
+
"空_EN": 496,
|
| 540 |
+
"迪娜泽黛_EN": 497,
|
| 541 |
+
"凯瑟琳_EN": 498,
|
| 542 |
+
"多莉_EN": 499,
|
| 543 |
+
"坎蒂丝_EN": 500,
|
| 544 |
+
"琳妮特_EN": 501,
|
| 545 |
+
"萍姥姥_EN": 502,
|
| 546 |
+
"罗莎莉亚_EN": 503,
|
| 547 |
+
"埃德_EN": 504,
|
| 548 |
+
"爱贝尔_EN": 505,
|
| 549 |
+
"伊迪娅_EN": 506,
|
| 550 |
+
"留云借风真君_EN": 507,
|
| 551 |
+
"绮良良_EN": 508,
|
| 552 |
+
"陌生人_EN": 509,
|
| 553 |
+
"七七_EN": 510,
|
| 554 |
+
"式大将_EN": 511,
|
| 555 |
+
"瑶瑶_EN": 512,
|
| 556 |
+
"奥兹_EN": 513,
|
| 557 |
+
"菲米尼_EN": 514,
|
| 558 |
+
"米卡_EN": 515,
|
| 559 |
+
"哲平_EN": 516,
|
| 560 |
+
"浮游水蕈兽·元素生命_EN": 517,
|
| 561 |
+
"大肉丸_EN": 518,
|
| 562 |
+
"托克_EN": 519,
|
| 563 |
+
"蒂玛乌斯_EN": 520,
|
| 564 |
+
"昆钧_EN": 521,
|
| 565 |
+
"欧菲妮_EN": 522,
|
| 566 |
+
"塞琉斯_EN": 523,
|
| 567 |
+
"仆人_EN": 524,
|
| 568 |
+
"迈勒斯_EN": 525,
|
| 569 |
+
"希格雯_EN": 526,
|
| 570 |
+
"阿守_EN": 527,
|
| 571 |
+
"拉赫曼_EN": 528,
|
| 572 |
+
"杜拉夫_EN": 529,
|
| 573 |
+
"伊利亚斯_EN": 530,
|
| 574 |
+
"阿晃_EN": 531,
|
| 575 |
+
"旁白_EN": 532,
|
| 576 |
+
"爱德琳_EN": 533,
|
| 577 |
+
"埃洛伊_EN": 534,
|
| 578 |
+
"德沃沙克_EN": 535,
|
| 579 |
+
"玛乔丽_EN": 536,
|
| 580 |
+
"塞塔蕾_EN": 537,
|
| 581 |
+
"柊千里_EN": 538,
|
| 582 |
+
"海芭夏_EN": 539,
|
| 583 |
+
"九条镰治_EN": 540,
|
| 584 |
+
"阿娜耶_EN": 541,
|
| 585 |
+
"笼钓瓶一心_EN": 542,
|
| 586 |
+
"回声海螺_EN": 543,
|
| 587 |
+
"劳维克_EN": 544,
|
| 588 |
+
"元太_EN": 545,
|
| 589 |
+
"阿扎尔_EN": 546,
|
| 590 |
+
"查尔斯_EN": 547,
|
| 591 |
+
"阿洛瓦_EN": 548,
|
| 592 |
+
"埃勒曼_EN": 549,
|
| 593 |
+
"纳比尔_EN": 550,
|
| 594 |
+
"莎拉_EN": 551,
|
| 595 |
+
"康纳_EN": 552,
|
| 596 |
+
"博来_EN": 553,
|
| 597 |
+
"玛塞勒_EN": 554,
|
| 598 |
+
"阿祇_EN": 555,
|
| 599 |
+
"博士_EN": 556,
|
| 600 |
+
"迪尔菲_EN": 557,
|
| 601 |
+
"宛烟_EN": 558,
|
| 602 |
+
"玛格丽特_EN": 559,
|
| 603 |
+
"羽生田千鹤_EN": 560,
|
| 604 |
+
"海妮耶_EN": 561,
|
| 605 |
+
"霍夫曼_EN": 562,
|
| 606 |
+
"旅行者_EN": 563,
|
| 607 |
+
"佐西摩斯_EN": 564,
|
| 608 |
+
"鹿野奈奈_EN": 565,
|
| 609 |
+
"舒伯特_EN": 566,
|
| 610 |
+
"天叔_EN": 567,
|
| 611 |
+
"艾莉丝_EN": 568,
|
| 612 |
+
"龙二_EN": 569,
|
| 613 |
+
"莺儿_EN": 570,
|
| 614 |
+
"嘉良_EN": 571,
|
| 615 |
+
"珊瑚_EN": 572,
|
| 616 |
+
"费迪南德_EN": 573,
|
| 617 |
+
"言笑_EN": 574,
|
| 618 |
+
"一心传名刀_EN": 575,
|
| 619 |
+
"久利须_EN": 576,
|
| 620 |
+
"嘉玛_EN": 577,
|
| 621 |
+
"艾文_EN": 578,
|
| 622 |
+
"克洛琳德_EN": 579,
|
| 623 |
+
"丹吉尔_EN": 580,
|
| 624 |
+
"女士_EN": 581,
|
| 625 |
+
"天目十五_EN": 582,
|
| 626 |
+
"老孟_EN": 583,
|
| 627 |
+
"白老先生_EN": 584,
|
| 628 |
+
"舍利夫_EN": 585,
|
| 629 |
+
"巴达维_EN": 586,
|
| 630 |
+
"拉齐_EN": 587,
|
| 631 |
+
"长生_EN": 588,
|
| 632 |
+
"吴船长_EN": 589,
|
| 633 |
+
"艾伯特_EN": 590,
|
| 634 |
+
"松浦_EN": 591,
|
| 635 |
+
"埃泽_EN": 592,
|
| 636 |
+
"阿圆_EN": 593,
|
| 637 |
+
"阿拉夫_EN": 594,
|
| 638 |
+
"莫塞伊思_EN": 595,
|
| 639 |
+
"石头_EN": 596,
|
| 640 |
+
"百闻_EN": 597,
|
| 641 |
+
"杜吉耶_EN": 598,
|
| 642 |
+
"波洛_EN": 599,
|
| 643 |
+
"斯坦利_EN": 600,
|
| 644 |
+
"掇星攫辰天君_EN": 601,
|
| 645 |
+
"迈蒙_EN": 602,
|
| 646 |
+
"博易_EN": 603,
|
| 647 |
+
"诗筠_EN": 604,
|
| 648 |
+
"毗伽尔_EN": 605,
|
| 649 |
+
"慧心_EN": 606,
|
| 650 |
+
"芙卡洛斯_EN": 607,
|
| 651 |
+
"恶龙_EN": 608,
|
| 652 |
+
"小仓澪_EN": 609,
|
| 653 |
+
"恕筠_EN": 610,
|
| 654 |
+
"知易_EN": 611,
|
| 655 |
+
"克列门特_EN": 612,
|
| 656 |
+
"大慈树王_EN": 613,
|
| 657 |
+
"维多利亚_EN": 614,
|
| 658 |
+
"黑田_EN": 615,
|
| 659 |
+
"马姆杜_EN": 616,
|
| 660 |
+
"科林斯_EN": 617,
|
| 661 |
+
"上杉_EN": 618,
|
| 662 |
+
"西拉杰_EN": 619,
|
| 663 |
+
"宁禄_EN": 620,
|
| 664 |
+
"纯水精灵_EN": 621,
|
| 665 |
+
"常九爷_EN": 622,
|
| 666 |
+
"阿尔卡米_EN": 623,
|
| 667 |
+
"沙扎曼_EN": 624,
|
| 668 |
+
"田铁嘴_EN": 625,
|
| 669 |
+
"加萨尼_EN": 626,
|
| 670 |
+
"克罗索_EN": 627,
|
| 671 |
+
"星稀_EN": 628,
|
| 672 |
+
"莱斯格_EN": 629,
|
| 673 |
+
"阿巴图伊_EN": 630,
|
| 674 |
+
"埃尔欣根_EN": 631,
|
| 675 |
+
"阿佩普_EN": 632,
|
| 676 |
+
"萨赫哈蒂_EN": 633,
|
| 677 |
+
"洛伦佐_EN": 634,
|
| 678 |
+
"塔杰·拉德卡尼_EN": 635,
|
| 679 |
+
"泽田_EN": 636,
|
| 680 |
+
"安西_EN": 637,
|
| 681 |
"埃舍尔_EN": 638,
|
| 682 |
+
"三月七_ZH": 639,
|
| 683 |
+
"丹恒_ZH": 640,
|
| 684 |
+
"希儿_ZH": 641,
|
| 685 |
+
"娜塔莎_ZH": 642,
|
| 686 |
+
"希露瓦_ZH": 643,
|
| 687 |
+
"瓦尔特_ZH": 644,
|
| 688 |
+
"佩拉_ZH": 645,
|
| 689 |
+
"布洛妮娅_ZH": 646,
|
| 690 |
+
"虎克_ZH": 647,
|
| 691 |
+
"素裳_ZH": 648,
|
| 692 |
+
"克拉拉_ZH": 649,
|
| 693 |
+
"符玄_ZH": 650,
|
| 694 |
+
"白露_ZH": 651,
|
| 695 |
+
"杰帕德_ZH": 652,
|
| 696 |
+
"景元_ZH": 653,
|
| 697 |
+
"藿藿_ZH": 654,
|
| 698 |
+
"姬子_ZH": 655,
|
| 699 |
+
"穹_ZH": 656,
|
| 700 |
+
"星_ZH": 657,
|
| 701 |
+
"卡芙卡_ZH": 658,
|
| 702 |
+
"桂乃芬_ZH": 659,
|
| 703 |
+
"艾丝妲_ZH": 660,
|
| 704 |
+
"玲可_ZH": 661,
|
| 705 |
+
"彦卿_ZH": 662,
|
| 706 |
+
"托帕_ZH": 663,
|
| 707 |
+
"驭空_ZH": 664,
|
| 708 |
+
"浮烟_ZH": 665,
|
| 709 |
+
"停云_ZH": 666,
|
| 710 |
+
"镜流_ZH": 667,
|
| 711 |
+
"罗刹_ZH": 668,
|
| 712 |
+
"卢卡_ZH": 669,
|
| 713 |
+
"史瓦罗_ZH": 670,
|
| 714 |
+
"黑塔_ZH": 671,
|
| 715 |
+
"桑博_ZH": 672,
|
| 716 |
+
"伦纳德_ZH": 673,
|
| 717 |
+
"明曦_ZH": 674,
|
| 718 |
+
"银狼_ZH": 675,
|
| 719 |
+
"帕姆_ZH": 676,
|
| 720 |
+
"青雀_ZH": 677,
|
| 721 |
+
"乔瓦尼_ZH": 678,
|
| 722 |
+
"公输师傅_ZH": 679,
|
| 723 |
+
"晴霓_ZH": 680,
|
| 724 |
+
"螺丝咕姆_ZH": 681,
|
| 725 |
+
"阿兰_ZH": 682,
|
| 726 |
+
"奥列格_ZH": 683,
|
| 727 |
+
"丹枢_ZH": 684,
|
| 728 |
+
"尾巴_ZH": 685,
|
| 729 |
+
"寒鸦_ZH": 686,
|
| 730 |
+
"雪衣_ZH": 687,
|
| 731 |
+
"可可利亚_ZH": 688,
|
| 732 |
+
"青镞_ZH": 689,
|
| 733 |
+
"半夏_ZH": 690,
|
| 734 |
+
"银枝_ZH": 691,
|
| 735 |
+
"大毫_ZH": 692,
|
| 736 |
+
"霄翰_ZH": 693,
|
| 737 |
+
"信使_ZH": 694,
|
| 738 |
+
"费斯曼_ZH": 695,
|
| 739 |
+
"绿芙蓉_ZH": 696,
|
| 740 |
+
"金人会长_ZH": 697,
|
| 741 |
+
"维利特_ZH": 698,
|
| 742 |
+
"维尔德_ZH": 699,
|
| 743 |
+
"斯科特_ZH": 700,
|
| 744 |
+
"卡波特_ZH": 701,
|
| 745 |
+
"刃_ZH": 702,
|
| 746 |
+
"岩明_ZH": 703,
|
| 747 |
+
"浣溪_ZH": 704,
|
| 748 |
+
"三月七_JP": 705,
|
| 749 |
+
"丹恒_JP": 706,
|
| 750 |
+
"希儿_JP": 707,
|
| 751 |
+
"娜塔莎_JP": 708,
|
| 752 |
+
"希露瓦_JP": 709,
|
| 753 |
+
"瓦尔特_JP": 710,
|
| 754 |
+
"佩拉_JP": 711,
|
| 755 |
+
"布洛妮娅_JP": 712,
|
| 756 |
+
"虎克_JP": 713,
|
| 757 |
+
"素裳_JP": 714,
|
| 758 |
+
"克拉拉_JP": 715,
|
| 759 |
+
"符玄_JP": 716,
|
| 760 |
+
"白露_JP": 717,
|
| 761 |
+
"杰帕德_JP": 718,
|
| 762 |
+
"景元_JP": 719,
|
| 763 |
+
"藿藿_JP": 720,
|
| 764 |
+
"姬子_JP": 721,
|
| 765 |
+
"卡芙卡_JP": 722,
|
| 766 |
+
"穹_JP": 723,
|
| 767 |
+
"星_JP": 724,
|
| 768 |
+
"桂乃芬_JP": 725,
|
| 769 |
+
"艾丝妲_JP": 726,
|
| 770 |
+
"彦卿_JP": 727,
|
| 771 |
+
"玲可_JP": 728,
|
| 772 |
+
"托帕_JP": 729,
|
| 773 |
+
"驭空_JP": 730,
|
| 774 |
+
"浮烟_JP": 731,
|
| 775 |
+
"停云_JP": 732,
|
| 776 |
+
"镜流_JP": 733,
|
| 777 |
+
"罗刹_JP": 734,
|
| 778 |
+
"卢卡_JP": 735,
|
| 779 |
+
"史瓦罗_JP": 736,
|
| 780 |
+
"黑塔_JP": 737,
|
| 781 |
+
"桑博_JP": 738,
|
| 782 |
+
"伦纳德_JP": 739,
|
| 783 |
+
"明曦_JP": 740,
|
| 784 |
+
"银狼_JP": 741,
|
| 785 |
+
"帕姆_JP": 742,
|
| 786 |
+
"青雀_JP": 743,
|
| 787 |
+
"乔瓦尼_JP": 744,
|
| 788 |
+
"公输师傅_JP": 745,
|
| 789 |
+
"晴霓_JP": 746,
|
| 790 |
+
"螺丝咕姆_JP": 747,
|
| 791 |
+
"阿兰_JP": 748,
|
| 792 |
+
"奥列格_JP": 749,
|
| 793 |
+
"丹枢_JP": 750,
|
| 794 |
+
"尾巴_JP": 751,
|
| 795 |
+
"寒鸦_JP": 752,
|
| 796 |
+
"雪衣_JP": 753,
|
| 797 |
+
"可可利亚_JP": 754,
|
| 798 |
+
"青镞_JP": 755,
|
| 799 |
+
"半夏_JP": 756,
|
| 800 |
+
"银枝_JP": 757,
|
| 801 |
+
"大毫_JP": 758,
|
| 802 |
+
"霄翰_JP": 759,
|
| 803 |
+
"信使_JP": 760,
|
| 804 |
+
"费斯曼_JP": 761,
|
| 805 |
+
"绿芙蓉_JP": 762,
|
| 806 |
+
"金人会长_JP": 763,
|
| 807 |
+
"维利特_JP": 764,
|
| 808 |
+
"维尔德_JP": 765,
|
| 809 |
+
"斯科特_JP": 766,
|
| 810 |
+
"刃_JP": 767,
|
| 811 |
+
"卡波特_JP": 768,
|
| 812 |
+
"岩明_JP": 769,
|
| 813 |
+
"浣溪_JP": 770,
|
| 814 |
+
"净砚_JP": 771,
|
| 815 |
+
"紫月季_JP": 772,
|
| 816 |
+
"歌蒂_JP": 773,
|
| 817 |
+
"奇怪的云骑_JP": 774,
|
| 818 |
+
"幻胧_JP": 775,
|
| 819 |
+
"斯薇塔_JP": 776,
|
| 820 |
+
"隐书_JP": 777,
|
| 821 |
+
"三月七_EN": 778,
|
| 822 |
+
"丹恒_EN": 779,
|
| 823 |
+
"希儿_EN": 780,
|
| 824 |
+
"娜塔莎_EN": 781,
|
| 825 |
+
"希露瓦_EN": 782,
|
| 826 |
+
"瓦尔特_EN": 783,
|
| 827 |
+
"佩拉_EN": 784,
|
| 828 |
+
"布洛妮娅_EN": 785,
|
| 829 |
+
"虎克_EN": 786,
|
| 830 |
+
"素裳_EN": 787,
|
| 831 |
+
"克拉拉_EN": 788,
|
| 832 |
+
"符玄_EN": 789,
|
| 833 |
+
"白露_EN": 790,
|
| 834 |
+
"杰帕德_EN": 791,
|
| 835 |
+
"景元_EN": 792,
|
| 836 |
+
"藿藿_EN": 793,
|
| 837 |
+
"姬子_EN": 794,
|
| 838 |
+
"卡芙卡_EN": 795,
|
| 839 |
+
"穹_EN": 796,
|
| 840 |
+
"星_EN": 797,
|
| 841 |
+
"桂乃芬_EN": 798,
|
| 842 |
+
"艾丝妲_EN": 799,
|
| 843 |
+
"彦卿_EN": 800,
|
| 844 |
+
"玲可_EN": 801,
|
| 845 |
+
"托帕_EN": 802,
|
| 846 |
+
"驭空_EN": 803,
|
| 847 |
+
"浮烟_EN": 804,
|
| 848 |
+
"停云_EN": 805,
|
| 849 |
+
"镜流_EN": 806,
|
| 850 |
+
"罗刹_EN": 807,
|
| 851 |
+
"卢卡_EN": 808,
|
| 852 |
+
"史瓦罗_EN": 809,
|
| 853 |
+
"黑塔_EN": 810,
|
| 854 |
+
"桑博_EN": 811,
|
| 855 |
+
"伦纳德_EN": 812,
|
| 856 |
+
"明曦_EN": 813,
|
| 857 |
+
"银狼_EN": 814,
|
| 858 |
+
"帕姆_EN": 815,
|
| 859 |
+
"青雀_EN": 816,
|
| 860 |
+
"乔瓦尼_EN": 817,
|
| 861 |
+
"公输师傅_EN": 818,
|
| 862 |
+
"晴霓_EN": 819,
|
| 863 |
+
"螺丝咕姆_EN": 820,
|
| 864 |
+
"阿兰_EN": 821,
|
| 865 |
+
"奥列格_EN": 822,
|
| 866 |
+
"丹枢_EN": 823,
|
| 867 |
+
"尾巴_EN": 824,
|
| 868 |
+
"寒鸦_EN": 825,
|
| 869 |
+
"雪衣_EN": 826,
|
| 870 |
+
"可可利亚_EN": 827,
|
| 871 |
+
"青镞_EN": 828,
|
| 872 |
+
"半夏_EN": 829,
|
| 873 |
+
"银枝_EN": 830,
|
| 874 |
+
"大毫_EN": 831,
|
| 875 |
+
"霄翰_EN": 832,
|
| 876 |
+
"信使_EN": 833,
|
| 877 |
+
"费斯曼_EN": 834,
|
| 878 |
+
"绿芙蓉_EN": 835,
|
| 879 |
+
"金人会长_EN": 836,
|
| 880 |
+
"维利特_EN": 837,
|
| 881 |
+
"维尔德_EN": 838,
|
| 882 |
+
"刃_EN": 839,
|
| 883 |
+
"卡波特_EN": 840,
|
| 884 |
+
"岩明_EN": 841,
|
| 885 |
+
"浣溪_EN": 842,
|
| 886 |
+
"紫月季_EN": 843,
|
| 887 |
+
"幻胧_EN": 844,
|
| 888 |
+
"女声_EN": 845,
|
| 889 |
+
"陆景和": 846,
|
| 890 |
+
"莫弈": 847,
|
| 891 |
+
"左然": 848,
|
| 892 |
+
"夏彦": 849
        }
    },
    "model": {
    ...
        ],
        "n_layers_q": 3,
        "use_spectral_norm": false,
+       "gin_channels": 512,
+       "slm": {
+           "model": "./slm/wavlm-base-plus",
+           "sr": 16000,
+           "hidden": 768,
+           "nlayers": 13,
+           "initial_channel": 64
+       }
    },
+   "version": "2.3"
}
|
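The "version": "2.3" field declared above is what the loaders use to pick a compatible model class, and the spk2id map is how a speaker name is turned into the integer id fed to the model. A minimal sketch of reading both from this file (the dataset path is a placeholder; utils.get_hparams_from_file and the hasattr fallback mirror the webui code later in this commit):

# Sketch only: read the declared version and a speaker id from a config.json.
import utils

hps = utils.get_hparams_from_file("Data/your_dataset/config.json")  # placeholder path
version = hps.version if hasattr(hps, "version") else "2.3"
sid = hps.data.spk2id["夏彦"]  # -> 849 in the map above
print(version, sid)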
data_utils.py
CHANGED
|
@@ -44,10 +44,6 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         self.min_text_len = getattr(hparams, "min_text_len", 1)
         self.max_text_len = getattr(hparams, "max_text_len", 384)
 
-        self.empty_emo = torch.squeeze(
-            torch.load("empty_emo.npy", map_location="cpu"), dim=1
-        )
-
         random.seed(1234)
         random.shuffle(self.audiopaths_sid_text)
         self._filter()
@@ -98,14 +94,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         spec, wav = self.get_audio(audiopath)
         sid = torch.LongTensor([int(self.spk_map[sid])])
 
-
-        emo = torch.squeeze(
-            torch.load(audiopath.replace(".wav", ".emo.npy"), map_location="cpu"),
-            dim=1,
-        )
-        else:
-            emo = self.empty_emo
-        return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert, emo)
+        return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert)
 
     def get_audio(self, filename):
         audio, sampling_rate = load_wav_to_torch(filename)
@@ -168,15 +157,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
 
         if language_str == "ZH":
             bert = bert_ori
-            ja_bert = torch.
-            en_bert = torch.
+            ja_bert = torch.randn(1024, len(phone))
+            en_bert = torch.randn(1024, len(phone))
         elif language_str == "JP":
-            bert = torch.
+            bert = torch.randn(1024, len(phone))
             ja_bert = bert_ori
-            en_bert = torch.
+            en_bert = torch.randn(1024, len(phone))
         elif language_str == "EN":
-            bert = torch.
-            ja_bert = torch.
+            bert = torch.randn(1024, len(phone))
+            ja_bert = torch.randn(1024, len(phone))
             en_bert = bert_ori
         phone = torch.LongTensor(phone)
         tone = torch.LongTensor(tone)
@@ -226,7 +215,6 @@ class TextAudioSpeakerCollate:
         bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
         ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
         en_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
-        emo = torch.FloatTensor(len(batch), 512)
 
         spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
         wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
@@ -238,7 +226,6 @@ class TextAudioSpeakerCollate:
         bert_padded.zero_()
         ja_bert_padded.zero_()
         en_bert_padded.zero_()
-        emo.zero_()
 
         for i in range(len(ids_sorted_decreasing)):
             row = batch[ids_sorted_decreasing[i]]
@@ -272,8 +259,6 @@ class TextAudioSpeakerCollate:
             en_bert = row[8]
             en_bert_padded[i, :, : en_bert.size(1)] = en_bert
 
-            emo[i, :] = row[9]
-
         return (
             text_padded,
             text_lengths,
@@ -287,7 +272,6 @@ class TextAudioSpeakerCollate:
             bert_padded,
             ja_bert_padded,
             en_bert_padded,
-            emo,
         )
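With the emotion embedding dropped from the dataset pipeline, each sample now carries exactly one real BERT feature (for its own language) plus random placeholders for the other two, all shaped 1024 x len(phone). A small torch-only sketch of that invariant, independent of the dataset class (the phoneme count is invented for illustration):

# Illustrative only: the per-sample BERT layout after this change.
import torch

phone_len = 17                            # hypothetical phoneme count
bert_ori = torch.zeros(1024, phone_len)   # stand-in for the real feature

# For a ZH sample: real feature in `bert`, random placeholders elsewhere.
bert = bert_ori
ja_bert = torch.randn(1024, phone_len)
en_bert = torch.randn(1024, phone_len)

sample = (bert, ja_bert, en_bert)         # plus phones/spec/wav/sid/tone/language
assert all(t.shape == (1024, phone_len) for t in sample)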
default_config.yml
CHANGED
|
@@ -83,11 +83,11 @@ train_ms:
   base:
     use_base_model: false
     repo_id: "Stardust_minus/Bert-VITS2"
-    model_image: "Bert-VITS2_2.
+    model_image: "Bert-VITS2_2.3底模" # openi网页的模型名
   # 训练模型存储目录:与旧版本的区别,原先数据集是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下
   model: "models"
   # 配置文件路径
-  config_path: "
+  config_path: "config.json"
   # 训练使用的worker,不建议超过CPU核心数
   num_workers: 16
   # 关闭此项可以节约接近50%的磁盘空间,但是可能导致实际训练速度变慢和更高的CPU使用率。
@@ -104,7 +104,7 @@ webui:
   # 模型路径
   model: "models/G_8000.pth"
   # 配置文件路径
-  config_path: "
+  config_path: "config.json"
   # 端口号
   port: 7860
   # 是否公开部署,对外网开放
export_onnx.py
CHANGED
|
@@ -2,11 +2,13 @@ from onnx_modules import export_onnx
 import os
 
 if __name__ == "__main__":
-    export_path = "
-    model_path = "
-    config_path = "
+    export_path = "BertVits2.2PT"
+    model_path = "model\\G_0.pth"
+    config_path = "model\\config.json"
+    novq = False
+    dev = False
     if not os.path.exists("onnx"):
         os.makedirs("onnx")
     if not os.path.exists(f"onnx/{export_path}"):
         os.makedirs(f"onnx/{export_path}")
-    export_onnx(export_path, model_path, config_path)
+    export_onnx(export_path, model_path, config_path, novq, dev)
for_deploy/infer.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
| 1 |
+
"""
|
| 2 |
+
版本管理、兼容推理及模型加载实现。
|
| 3 |
+
版本说明:
|
| 4 |
+
1. 版本号与github的release版本号对应,使用哪个release版本训练的模型即对应其版本号
|
| 5 |
+
2. 请在模型的config.json中显示声明版本号,添加一个字段"version" : "你的版本号"
|
| 6 |
+
特殊版本说明:
|
| 7 |
+
1.1.1-fix: 1.1.1版本训练的模型,但是在推理时使用dev的日语修复
|
| 8 |
+
2.2:当前版本
|
| 9 |
+
"""
|
| 10 |
+
import torch
|
| 11 |
+
import commons
|
| 12 |
+
from text import cleaned_text_to_sequence, get_bert
|
| 13 |
+
from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
|
| 14 |
+
from text.cleaner import clean_text
|
| 15 |
+
import utils
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
from models import SynthesizerTrn
|
| 19 |
+
from text.symbols import symbols
|
| 20 |
+
|
| 21 |
+
from oldVersion.V210.models import SynthesizerTrn as V210SynthesizerTrn
|
| 22 |
+
from oldVersion.V210.text import symbols as V210symbols
|
| 23 |
+
from oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
|
| 24 |
+
from oldVersion.V200.text import symbols as V200symbols
|
| 25 |
+
from oldVersion.V111.models import SynthesizerTrn as V111SynthesizerTrn
|
| 26 |
+
from oldVersion.V111.text import symbols as V111symbols
|
| 27 |
+
from oldVersion.V110.models import SynthesizerTrn as V110SynthesizerTrn
|
| 28 |
+
from oldVersion.V110.text import symbols as V110symbols
|
| 29 |
+
from oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
|
| 30 |
+
from oldVersion.V101.text import symbols as V101symbols
|
| 31 |
+
|
| 32 |
+
from oldVersion import V111, V110, V101, V200, V210
|
| 33 |
+
|
| 34 |
+
# 当前版本信息
|
| 35 |
+
latest_version = "2.2"
|
| 36 |
+
|
| 37 |
+
# 版本兼容
|
| 38 |
+
SynthesizerTrnMap = {
|
| 39 |
+
"2.1": V210SynthesizerTrn,
|
| 40 |
+
"2.0.2-fix": V200SynthesizerTrn,
|
| 41 |
+
"2.0.1": V200SynthesizerTrn,
|
| 42 |
+
"2.0": V200SynthesizerTrn,
|
| 43 |
+
"1.1.1-fix": V111SynthesizerTrn,
|
| 44 |
+
"1.1.1": V111SynthesizerTrn,
|
| 45 |
+
"1.1": V110SynthesizerTrn,
|
| 46 |
+
"1.1.0": V110SynthesizerTrn,
|
| 47 |
+
"1.0.1": V101SynthesizerTrn,
|
| 48 |
+
"1.0": V101SynthesizerTrn,
|
| 49 |
+
"1.0.0": V101SynthesizerTrn,
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
symbolsMap = {
|
| 53 |
+
"2.1": V210symbols,
|
| 54 |
+
"2.0.2-fix": V200symbols,
|
| 55 |
+
"2.0.1": V200symbols,
|
| 56 |
+
"2.0": V200symbols,
|
| 57 |
+
"1.1.1-fix": V111symbols,
|
| 58 |
+
"1.1.1": V111symbols,
|
| 59 |
+
"1.1": V110symbols,
|
| 60 |
+
"1.1.0": V110symbols,
|
| 61 |
+
"1.0.1": V101symbols,
|
| 62 |
+
"1.0": V101symbols,
|
| 63 |
+
"1.0.0": V101symbols,
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# def get_emo_(reference_audio, emotion, sid):
|
| 68 |
+
# emo = (
|
| 69 |
+
# torch.from_numpy(get_emo(reference_audio))
|
| 70 |
+
# if reference_audio and emotion == -1
|
| 71 |
+
# else torch.FloatTensor(
|
| 72 |
+
# np.load(f"emo_clustering/{sid}/cluster_center_{emotion}.npy")
|
| 73 |
+
# )
|
| 74 |
+
# )
|
| 75 |
+
# return emo
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_net_g(model_path: str, version: str, device: str, hps):
|
| 79 |
+
if version != latest_version:
|
| 80 |
+
net_g = SynthesizerTrnMap[version](
|
| 81 |
+
len(symbolsMap[version]),
|
| 82 |
+
hps.data.filter_length // 2 + 1,
|
| 83 |
+
hps.train.segment_size // hps.data.hop_length,
|
| 84 |
+
n_speakers=hps.data.n_speakers,
|
| 85 |
+
**hps.model,
|
| 86 |
+
).to(device)
|
| 87 |
+
else:
|
| 88 |
+
# 当前版本模型 net_g
|
| 89 |
+
net_g = SynthesizerTrn(
|
| 90 |
+
len(symbols),
|
| 91 |
+
hps.data.filter_length // 2 + 1,
|
| 92 |
+
hps.train.segment_size // hps.data.hop_length,
|
| 93 |
+
n_speakers=hps.data.n_speakers,
|
| 94 |
+
**hps.model,
|
| 95 |
+
).to(device)
|
| 96 |
+
_ = net_g.eval()
|
| 97 |
+
_ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
|
| 98 |
+
return net_g
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def get_text(text, language_str, bert, hps, device):
|
| 102 |
+
# 在此处实现当前版本的get_text
|
| 103 |
+
norm_text, phone, tone, word2ph = clean_text(text, language_str)
|
| 104 |
+
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
| 105 |
+
|
| 106 |
+
if hps.data.add_blank:
|
| 107 |
+
phone = commons.intersperse(phone, 0)
|
| 108 |
+
tone = commons.intersperse(tone, 0)
|
| 109 |
+
language = commons.intersperse(language, 0)
|
| 110 |
+
for i in range(len(word2ph)):
|
| 111 |
+
word2ph[i] = word2ph[i] * 2
|
| 112 |
+
word2ph[0] += 1
|
| 113 |
+
# bert_ori = get_bert(norm_text, word2ph, language_str, device)
|
| 114 |
+
bert_ori = bert[language_str].get_bert_feature(norm_text, word2ph, device)
|
| 115 |
+
del word2ph
|
| 116 |
+
assert bert_ori.shape[-1] == len(phone), phone
|
| 117 |
+
|
| 118 |
+
if language_str == "ZH":
|
| 119 |
+
bert = bert_ori
|
| 120 |
+
ja_bert = torch.randn(1024, len(phone))
|
| 121 |
+
en_bert = torch.randn(1024, len(phone))
|
| 122 |
+
elif language_str == "JP":
|
| 123 |
+
bert = torch.randn(1024, len(phone))
|
| 124 |
+
ja_bert = bert_ori
|
| 125 |
+
en_bert = torch.randn(1024, len(phone))
|
| 126 |
+
elif language_str == "EN":
|
| 127 |
+
bert = torch.randn(1024, len(phone))
|
| 128 |
+
ja_bert = torch.randn(1024, len(phone))
|
| 129 |
+
en_bert = bert_ori
|
| 130 |
+
else:
|
| 131 |
+
raise ValueError("language_str should be ZH, JP or EN")
|
| 132 |
+
|
| 133 |
+
assert bert.shape[-1] == len(
|
| 134 |
+
phone
|
| 135 |
+
), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
|
| 136 |
+
|
| 137 |
+
phone = torch.LongTensor(phone)
|
| 138 |
+
tone = torch.LongTensor(tone)
|
| 139 |
+
language = torch.LongTensor(language)
|
| 140 |
+
return bert, ja_bert, en_bert, phone, tone, language
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def infer(
|
| 144 |
+
text,
|
| 145 |
+
emotion,
|
| 146 |
+
sdp_ratio,
|
| 147 |
+
noise_scale,
|
| 148 |
+
noise_scale_w,
|
| 149 |
+
length_scale,
|
| 150 |
+
sid,
|
| 151 |
+
language,
|
| 152 |
+
hps,
|
| 153 |
+
net_g,
|
| 154 |
+
device,
|
| 155 |
+
bert=None,
|
| 156 |
+
clap=None,
|
| 157 |
+
reference_audio=None,
|
| 158 |
+
skip_start=False,
|
| 159 |
+
skip_end=False,
|
| 160 |
+
):
|
| 161 |
+
# 2.2版本参数位置变了
|
| 162 |
+
# 2.1 参数新增 emotion reference_audio skip_start skip_end
|
| 163 |
+
inferMap_V3 = {
|
| 164 |
+
"2.1": V210.infer,
|
| 165 |
+
}
|
| 166 |
+
# 支持中日英三语版本
|
| 167 |
+
inferMap_V2 = {
|
| 168 |
+
"2.0.2-fix": V200.infer,
|
| 169 |
+
"2.0.1": V200.infer,
|
| 170 |
+
"2.0": V200.infer,
|
| 171 |
+
"1.1.1-fix": V111.infer_fix,
|
| 172 |
+
"1.1.1": V111.infer,
|
| 173 |
+
"1.1": V110.infer,
|
| 174 |
+
"1.1.0": V110.infer,
|
| 175 |
+
}
|
| 176 |
+
# 仅支持中文版本
|
| 177 |
+
# 在测试中,并未发现两个版本的模型不能互相通用
|
| 178 |
+
inferMap_V1 = {
|
| 179 |
+
"1.0.1": V101.infer,
|
| 180 |
+
"1.0": V101.infer,
|
| 181 |
+
"1.0.0": V101.infer,
|
| 182 |
+
}
|
| 183 |
+
version = hps.version if hasattr(hps, "version") else latest_version
|
| 184 |
+
# 非当前版本,根据版本号选择合适的infer
|
| 185 |
+
if version != latest_version:
|
| 186 |
+
if version in inferMap_V3.keys():
|
| 187 |
+
return inferMap_V3[version](
|
| 188 |
+
text,
|
| 189 |
+
sdp_ratio,
|
| 190 |
+
noise_scale,
|
| 191 |
+
noise_scale_w,
|
| 192 |
+
length_scale,
|
| 193 |
+
sid,
|
| 194 |
+
language,
|
| 195 |
+
hps,
|
| 196 |
+
net_g,
|
| 197 |
+
device,
|
| 198 |
+
reference_audio,
|
| 199 |
+
emotion,
|
| 200 |
+
skip_start,
|
| 201 |
+
skip_end,
|
| 202 |
+
)
|
| 203 |
+
if version in inferMap_V2.keys():
|
| 204 |
+
return inferMap_V2[version](
|
| 205 |
+
text,
|
| 206 |
+
sdp_ratio,
|
| 207 |
+
noise_scale,
|
| 208 |
+
noise_scale_w,
|
| 209 |
+
length_scale,
|
| 210 |
+
sid,
|
| 211 |
+
language,
|
| 212 |
+
hps,
|
| 213 |
+
net_g,
|
| 214 |
+
device,
|
| 215 |
+
)
|
| 216 |
+
if version in inferMap_V1.keys():
|
| 217 |
+
return inferMap_V1[version](
|
| 218 |
+
text,
|
| 219 |
+
sdp_ratio,
|
| 220 |
+
noise_scale,
|
| 221 |
+
noise_scale_w,
|
| 222 |
+
length_scale,
|
| 223 |
+
sid,
|
| 224 |
+
hps,
|
| 225 |
+
net_g,
|
| 226 |
+
device,
|
| 227 |
+
)
|
| 228 |
+
# 在此处实现当前版本的推理
|
| 229 |
+
# emo = get_emo_(reference_audio, emotion, sid)
|
| 230 |
+
if isinstance(reference_audio, np.ndarray):
|
| 231 |
+
emo = clap.get_clap_audio_feature(reference_audio, device)
|
| 232 |
+
else:
|
| 233 |
+
emo = clap.get_clap_text_feature(emotion, device)
|
| 234 |
+
emo = torch.squeeze(emo, dim=1)
|
| 235 |
+
|
| 236 |
+
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
|
| 237 |
+
text, language, bert, hps, device
|
| 238 |
+
)
|
| 239 |
+
if skip_start:
|
| 240 |
+
phones = phones[3:]
|
| 241 |
+
tones = tones[3:]
|
| 242 |
+
lang_ids = lang_ids[3:]
|
| 243 |
+
bert = bert[:, 3:]
|
| 244 |
+
ja_bert = ja_bert[:, 3:]
|
| 245 |
+
en_bert = en_bert[:, 3:]
|
| 246 |
+
if skip_end:
|
| 247 |
+
phones = phones[:-2]
|
| 248 |
+
tones = tones[:-2]
|
| 249 |
+
lang_ids = lang_ids[:-2]
|
| 250 |
+
bert = bert[:, :-2]
|
| 251 |
+
ja_bert = ja_bert[:, :-2]
|
| 252 |
+
en_bert = en_bert[:, :-2]
|
| 253 |
+
with torch.no_grad():
|
| 254 |
+
x_tst = phones.to(device).unsqueeze(0)
|
| 255 |
+
tones = tones.to(device).unsqueeze(0)
|
| 256 |
+
lang_ids = lang_ids.to(device).unsqueeze(0)
|
| 257 |
+
bert = bert.to(device).unsqueeze(0)
|
| 258 |
+
ja_bert = ja_bert.to(device).unsqueeze(0)
|
| 259 |
+
en_bert = en_bert.to(device).unsqueeze(0)
|
| 260 |
+
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
| 261 |
+
emo = emo.to(device).unsqueeze(0)
|
| 262 |
+
del phones
|
| 263 |
+
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
| 264 |
+
audio = (
|
| 265 |
+
net_g.infer(
|
| 266 |
+
x_tst,
|
| 267 |
+
x_tst_lengths,
|
| 268 |
+
speakers,
|
| 269 |
+
tones,
|
| 270 |
+
lang_ids,
|
| 271 |
+
bert,
|
| 272 |
+
ja_bert,
|
| 273 |
+
en_bert,
|
| 274 |
+
emo,
|
| 275 |
+
sdp_ratio=sdp_ratio,
|
| 276 |
+
noise_scale=noise_scale,
|
| 277 |
+
noise_scale_w=noise_scale_w,
|
| 278 |
+
length_scale=length_scale,
|
| 279 |
+
)[0][0, 0]
|
| 280 |
+
.data.cpu()
|
| 281 |
+
.float()
|
| 282 |
+
.numpy()
|
| 283 |
+
)
|
| 284 |
+
del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
|
| 285 |
+
if torch.cuda.is_available():
|
| 286 |
+
torch.cuda.empty_cache()
|
| 287 |
+
return audio
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def infer_multilang(
|
| 291 |
+
text,
|
| 292 |
+
sdp_ratio,
|
| 293 |
+
noise_scale,
|
| 294 |
+
noise_scale_w,
|
| 295 |
+
length_scale,
|
| 296 |
+
sid,
|
| 297 |
+
language,
|
| 298 |
+
hps,
|
| 299 |
+
net_g,
|
| 300 |
+
device,
|
| 301 |
+
bert=None,
|
| 302 |
+
clap=None,
|
| 303 |
+
reference_audio=None,
|
| 304 |
+
emotion=None,
|
| 305 |
+
skip_start=False,
|
| 306 |
+
skip_end=False,
|
| 307 |
+
):
|
| 308 |
+
bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
|
| 309 |
+
# emo = get_emo_(reference_audio, emotion, sid)
|
| 310 |
+
if isinstance(reference_audio, np.ndarray):
|
| 311 |
+
emo = clap.get_clap_audio_feature(reference_audio, device)
|
| 312 |
+
else:
|
| 313 |
+
emo = clap.get_clap_text_feature(emotion, device)
|
| 314 |
+
emo = torch.squeeze(emo, dim=1)
|
| 315 |
+
for idx, (txt, lang) in enumerate(zip(text, language)):
|
| 316 |
+
skip_start = (idx != 0) or (skip_start and idx == 0)
|
| 317 |
+
skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1)
|
| 318 |
+
(
|
| 319 |
+
temp_bert,
|
| 320 |
+
temp_ja_bert,
|
| 321 |
+
temp_en_bert,
|
| 322 |
+
temp_phones,
|
| 323 |
+
temp_tones,
|
| 324 |
+
temp_lang_ids,
|
| 325 |
+
) = get_text(txt, lang, bert, hps, device)
|
| 326 |
+
if skip_start:
|
| 327 |
+
temp_bert = temp_bert[:, 3:]
|
| 328 |
+
temp_ja_bert = temp_ja_bert[:, 3:]
|
| 329 |
+
temp_en_bert = temp_en_bert[:, 3:]
|
| 330 |
+
temp_phones = temp_phones[3:]
|
| 331 |
+
temp_tones = temp_tones[3:]
|
| 332 |
+
temp_lang_ids = temp_lang_ids[3:]
|
| 333 |
+
if skip_end:
|
| 334 |
+
temp_bert = temp_bert[:, :-2]
|
| 335 |
+
temp_ja_bert = temp_ja_bert[:, :-2]
|
| 336 |
+
temp_en_bert = temp_en_bert[:, :-2]
|
| 337 |
+
temp_phones = temp_phones[:-2]
|
| 338 |
+
temp_tones = temp_tones[:-2]
|
| 339 |
+
temp_lang_ids = temp_lang_ids[:-2]
|
| 340 |
+
bert.append(temp_bert)
|
| 341 |
+
ja_bert.append(temp_ja_bert)
|
| 342 |
+
en_bert.append(temp_en_bert)
|
| 343 |
+
phones.append(temp_phones)
|
| 344 |
+
tones.append(temp_tones)
|
| 345 |
+
lang_ids.append(temp_lang_ids)
|
| 346 |
+
bert = torch.concatenate(bert, dim=1)
|
| 347 |
+
ja_bert = torch.concatenate(ja_bert, dim=1)
|
| 348 |
+
en_bert = torch.concatenate(en_bert, dim=1)
|
| 349 |
+
phones = torch.concatenate(phones, dim=0)
|
| 350 |
+
tones = torch.concatenate(tones, dim=0)
|
| 351 |
+
lang_ids = torch.concatenate(lang_ids, dim=0)
|
| 352 |
+
with torch.no_grad():
|
| 353 |
+
x_tst = phones.to(device).unsqueeze(0)
|
| 354 |
+
tones = tones.to(device).unsqueeze(0)
|
| 355 |
+
lang_ids = lang_ids.to(device).unsqueeze(0)
|
| 356 |
+
bert = bert.to(device).unsqueeze(0)
|
| 357 |
+
ja_bert = ja_bert.to(device).unsqueeze(0)
|
| 358 |
+
en_bert = en_bert.to(device).unsqueeze(0)
|
| 359 |
+
emo = emo.to(device).unsqueeze(0)
|
| 360 |
+
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
| 361 |
+
del phones
|
| 362 |
+
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
| 363 |
+
audio = (
|
| 364 |
+
net_g.infer(
|
| 365 |
+
x_tst,
|
| 366 |
+
x_tst_lengths,
|
| 367 |
+
speakers,
|
| 368 |
+
tones,
|
| 369 |
+
lang_ids,
|
| 370 |
+
bert,
|
| 371 |
+
ja_bert,
|
| 372 |
+
en_bert,
|
| 373 |
+
emo,
|
| 374 |
+
sdp_ratio=sdp_ratio,
|
| 375 |
+
noise_scale=noise_scale,
|
| 376 |
+
noise_scale_w=noise_scale_w,
|
| 377 |
+
length_scale=length_scale,
|
| 378 |
+
)[0][0, 0]
|
| 379 |
+
.data.cpu()
|
| 380 |
+
.float()
|
| 381 |
+
.numpy()
|
| 382 |
+
)
|
| 383 |
+
del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
|
| 384 |
+
if torch.cuda.is_available():
|
| 385 |
+
torch.cuda.empty_cache()
|
| 386 |
+
return audio
|
for_deploy/infer_utils.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import (
|
| 5 |
+
AutoModelForMaskedLM,
|
| 6 |
+
AutoTokenizer,
|
| 7 |
+
DebertaV2Model,
|
| 8 |
+
DebertaV2Tokenizer,
|
| 9 |
+
ClapModel,
|
| 10 |
+
ClapProcessor,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from config import config
|
| 14 |
+
from text.japanese import text2sep_kata
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BertFeature:
|
| 18 |
+
def __init__(self, model_path, language="ZH"):
|
| 19 |
+
self.model_path = model_path
|
| 20 |
+
self.language = language
|
| 21 |
+
self.tokenizer = None
|
| 22 |
+
self.model = None
|
| 23 |
+
self.device = None
|
| 24 |
+
|
| 25 |
+
self._prepare()
|
| 26 |
+
|
| 27 |
+
def _get_device(self, device=config.bert_gen_config.device):
|
| 28 |
+
if (
|
| 29 |
+
sys.platform == "darwin"
|
| 30 |
+
and torch.backends.mps.is_available()
|
| 31 |
+
and device == "cpu"
|
| 32 |
+
):
|
| 33 |
+
device = "mps"
|
| 34 |
+
if not device:
|
| 35 |
+
device = "cuda"
|
| 36 |
+
return device
|
| 37 |
+
|
| 38 |
+
def _prepare(self):
|
| 39 |
+
self.device = self._get_device()
|
| 40 |
+
|
| 41 |
+
if self.language == "EN":
|
| 42 |
+
self.tokenizer = DebertaV2Tokenizer.from_pretrained(self.model_path)
|
| 43 |
+
self.model = DebertaV2Model.from_pretrained(self.model_path).to(self.device)
|
| 44 |
+
else:
|
| 45 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
|
| 46 |
+
self.model = AutoModelForMaskedLM.from_pretrained(self.model_path).to(
|
| 47 |
+
self.device
|
| 48 |
+
)
|
| 49 |
+
self.model.eval()
|
| 50 |
+
|
| 51 |
+
def get_bert_feature(self, text, word2ph):
|
| 52 |
+
if self.language == "JP":
|
| 53 |
+
text = "".join(text2sep_kata(text)[0])
|
| 54 |
+
with torch.no_grad():
|
| 55 |
+
inputs = self.tokenizer(text, return_tensors="pt")
|
| 56 |
+
for i in inputs:
|
| 57 |
+
inputs[i] = inputs[i].to(self.device)
|
| 58 |
+
res = self.model(**inputs, output_hidden_states=True)
|
| 59 |
+
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
|
| 60 |
+
|
| 61 |
+
word2phone = word2ph
|
| 62 |
+
phone_level_feature = []
|
| 63 |
+
for i in range(len(word2phone)):
|
| 64 |
+
repeat_feature = res[i].repeat(word2phone[i], 1)
|
| 65 |
+
phone_level_feature.append(repeat_feature)
|
| 66 |
+
|
| 67 |
+
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
| 68 |
+
|
| 69 |
+
return phone_level_feature.T
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class ClapFeature:
|
| 73 |
+
def __init__(self, model_path):
|
| 74 |
+
self.model_path = model_path
|
| 75 |
+
self.processor = None
|
| 76 |
+
self.model = None
|
| 77 |
+
self.device = None
|
| 78 |
+
|
| 79 |
+
self._prepare()
|
| 80 |
+
|
| 81 |
+
def _get_device(self, device=config.bert_gen_config.device):
|
| 82 |
+
if (
|
| 83 |
+
sys.platform == "darwin"
|
| 84 |
+
and torch.backends.mps.is_available()
|
| 85 |
+
and device == "cpu"
|
| 86 |
+
):
|
| 87 |
+
device = "mps"
|
| 88 |
+
if not device:
|
| 89 |
+
device = "cuda"
|
| 90 |
+
return device
|
| 91 |
+
|
| 92 |
+
def _prepare(self):
|
| 93 |
+
self.device = self._get_device()
|
| 94 |
+
|
| 95 |
+
self.processor = ClapProcessor.from_pretrained(self.model_path)
|
| 96 |
+
self.model = ClapModel.from_pretrained(self.model_path).to(self.device)
|
| 97 |
+
self.model.eval()
|
| 98 |
+
|
| 99 |
+
def get_clap_audio_feature(self, audio_data):
|
| 100 |
+
with torch.no_grad():
|
| 101 |
+
inputs = self.processor(
|
| 102 |
+
audios=audio_data, return_tensors="pt", sampling_rate=48000
|
| 103 |
+
).to(self.device)
|
| 104 |
+
emb = self.model.get_audio_features(**inputs)
|
| 105 |
+
return emb.T
|
| 106 |
+
|
| 107 |
+
def get_clap_text_feature(self, text):
|
| 108 |
+
with torch.no_grad():
|
| 109 |
+
inputs = self.processor(text=text, return_tensors="pt").to(self.device)
|
| 110 |
+
emb = self.model.get_text_features(**inputs)
|
| 111 |
+
return emb.T
|
for_deploy/webui.py
ADDED
|
@@ -0,0 +1,556 @@
|
|
|
|
|
| 1 |
+
# flake8: noqa: E402
|
| 2 |
+
import os
|
| 3 |
+
import logging
|
| 4 |
+
import re_matching
|
| 5 |
+
from tools.sentence import split_by_language
|
| 6 |
+
|
| 7 |
+
logging.getLogger("numba").setLevel(logging.WARNING)
|
| 8 |
+
logging.getLogger("markdown_it").setLevel(logging.WARNING)
|
| 9 |
+
logging.getLogger("urllib3").setLevel(logging.WARNING)
|
| 10 |
+
logging.getLogger("matplotlib").setLevel(logging.WARNING)
|
| 11 |
+
|
| 12 |
+
logging.basicConfig(
|
| 13 |
+
level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import utils
|
| 20 |
+
from infer import infer, latest_version, get_net_g, infer_multilang
|
| 21 |
+
import gradio as gr
|
| 22 |
+
import webbrowser
|
| 23 |
+
import numpy as np
|
| 24 |
+
from config import config
|
| 25 |
+
from tools.translate import translate
|
| 26 |
+
import librosa
|
| 27 |
+
from infer_utils import BertFeature, ClapFeature
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
net_g = None
|
| 31 |
+
|
| 32 |
+
device = config.webui_config.device
|
| 33 |
+
if device == "mps":
|
| 34 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
| 35 |
+
|
| 36 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
| 37 |
+
os.environ["MKL_NUM_THREADS"] = "1"
|
| 38 |
+
|
| 39 |
+
bert_feature_map = {
|
| 40 |
+
"ZH": BertFeature(
|
| 41 |
+
"./bert/chinese-roberta-wwm-ext-large",
|
| 42 |
+
language="ZH",
|
| 43 |
+
),
|
| 44 |
+
"JP": BertFeature(
|
| 45 |
+
"./bert/deberta-v2-large-japanese-char-wwm",
|
| 46 |
+
language="JP",
|
| 47 |
+
),
|
| 48 |
+
"EN": BertFeature(
|
| 49 |
+
"./bert/deberta-v3-large",
|
| 50 |
+
language="EN",
|
| 51 |
+
),
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
clap_feature = ClapFeature("./emotional/clap-htsat-fused")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def generate_audio(
|
| 58 |
+
slices,
|
| 59 |
+
sdp_ratio,
|
| 60 |
+
noise_scale,
|
| 61 |
+
noise_scale_w,
|
| 62 |
+
length_scale,
|
| 63 |
+
speaker,
|
| 64 |
+
language,
|
| 65 |
+
reference_audio,
|
| 66 |
+
emotion,
|
| 67 |
+
skip_start=False,
|
| 68 |
+
skip_end=False,
|
| 69 |
+
):
|
| 70 |
+
audio_list = []
|
| 71 |
+
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
|
| 72 |
+
with torch.no_grad():
|
| 73 |
+
for idx, piece in enumerate(slices):
|
| 74 |
+
skip_start = (idx != 0) and skip_start
|
| 75 |
+
skip_end = (idx != len(slices) - 1) and skip_end
|
| 76 |
+
audio = infer(
|
| 77 |
+
piece,
|
| 78 |
+
reference_audio=reference_audio,
|
| 79 |
+
emotion=emotion,
|
| 80 |
+
sdp_ratio=sdp_ratio,
|
| 81 |
+
noise_scale=noise_scale,
|
| 82 |
+
noise_scale_w=noise_scale_w,
|
| 83 |
+
length_scale=length_scale,
|
| 84 |
+
sid=speaker,
|
| 85 |
+
language=language,
|
| 86 |
+
hps=hps,
|
| 87 |
+
net_g=net_g,
|
| 88 |
+
device=device,
|
| 89 |
+
skip_start=skip_start,
|
| 90 |
+
skip_end=skip_end,
|
| 91 |
+
bert=bert_feature_map,
|
| 92 |
+
clap=clap_feature,
|
| 93 |
+
)
|
| 94 |
+
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
|
| 95 |
+
audio_list.append(audio16bit)
|
| 96 |
+
# audio_list.append(silence) # 将静音添加到列表中
|
| 97 |
+
return audio_list
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def generate_audio_multilang(
|
| 101 |
+
slices,
|
| 102 |
+
sdp_ratio,
|
| 103 |
+
noise_scale,
|
| 104 |
+
noise_scale_w,
|
| 105 |
+
length_scale,
|
| 106 |
+
speaker,
|
| 107 |
+
language,
|
| 108 |
+
reference_audio,
|
| 109 |
+
emotion,
|
| 110 |
+
skip_start=False,
|
| 111 |
+
skip_end=False,
|
| 112 |
+
):
|
| 113 |
+
audio_list = []
|
| 114 |
+
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
|
| 115 |
+
with torch.no_grad():
|
| 116 |
+
for idx, piece in enumerate(slices):
|
| 117 |
+
skip_start = (idx != 0) and skip_start
|
| 118 |
+
skip_end = (idx != len(slices) - 1) and skip_end
|
| 119 |
+
audio = infer_multilang(
|
| 120 |
+
piece,
|
| 121 |
+
reference_audio=reference_audio,
|
| 122 |
+
emotion=emotion,
|
| 123 |
+
sdp_ratio=sdp_ratio,
|
| 124 |
+
noise_scale=noise_scale,
|
| 125 |
+
noise_scale_w=noise_scale_w,
|
| 126 |
+
length_scale=length_scale,
|
| 127 |
+
sid=speaker,
|
| 128 |
+
language=language[idx],
|
| 129 |
+
hps=hps,
|
| 130 |
+
net_g=net_g,
|
| 131 |
+
device=device,
|
| 132 |
+
skip_start=skip_start,
|
| 133 |
+
skip_end=skip_end,
|
| 134 |
+
)
|
| 135 |
+
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
|
| 136 |
+
audio_list.append(audio16bit)
|
| 137 |
+
# audio_list.append(silence) # 将静音添加到列表中
|
| 138 |
+
return audio_list
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def tts_split(
|
| 142 |
+
text: str,
|
| 143 |
+
speaker,
|
| 144 |
+
sdp_ratio,
|
| 145 |
+
noise_scale,
|
| 146 |
+
noise_scale_w,
|
| 147 |
+
length_scale,
|
| 148 |
+
language,
|
| 149 |
+
cut_by_sent,
|
| 150 |
+
interval_between_para,
|
| 151 |
+
interval_between_sent,
|
| 152 |
+
reference_audio,
|
| 153 |
+
emotion,
|
| 154 |
+
):
|
| 155 |
+
if language == "mix":
|
| 156 |
+
return ("invalid", None)
|
| 157 |
+
while text.find("\n\n") != -1:
|
| 158 |
+
text = text.replace("\n\n", "\n")
|
| 159 |
+
para_list = re_matching.cut_para(text)
|
| 160 |
+
audio_list = []
|
| 161 |
+
if not cut_by_sent:
|
| 162 |
+
for idx, p in enumerate(para_list):
|
| 163 |
+
skip_start = idx != 0
|
| 164 |
+
skip_end = idx != len(para_list) - 1
|
| 165 |
+
audio = infer(
|
| 166 |
+
p,
|
| 167 |
+
reference_audio=reference_audio,
|
| 168 |
+
emotion=emotion,
|
| 169 |
+
sdp_ratio=sdp_ratio,
|
| 170 |
+
noise_scale=noise_scale,
|
| 171 |
+
noise_scale_w=noise_scale_w,
|
| 172 |
+
length_scale=length_scale,
|
| 173 |
+
sid=speaker,
|
| 174 |
+
language=language,
|
| 175 |
+
hps=hps,
|
| 176 |
+
net_g=net_g,
|
| 177 |
+
device=device,
|
| 178 |
+
skip_start=skip_start,
|
| 179 |
+
skip_end=skip_end,
|
| 180 |
+
)
|
| 181 |
+
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
|
| 182 |
+
audio_list.append(audio16bit)
|
| 183 |
+
silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
|
| 184 |
+
audio_list.append(silence)
|
| 185 |
+
else:
|
| 186 |
+
for idx, p in enumerate(para_list):
|
| 187 |
+
skip_start = idx != 0
|
| 188 |
+
skip_end = idx != len(para_list) - 1
|
| 189 |
+
audio_list_sent = []
|
| 190 |
+
sent_list = re_matching.cut_sent(p)
|
| 191 |
+
for idx, s in enumerate(sent_list):
|
| 192 |
+
skip_start = (idx != 0) and skip_start
|
| 193 |
+
skip_end = (idx != len(sent_list) - 1) and skip_end
|
| 194 |
+
audio = infer(
|
| 195 |
+
s,
|
| 196 |
+
reference_audio=reference_audio,
|
| 197 |
+
emotion=emotion,
|
| 198 |
+
sdp_ratio=sdp_ratio,
|
| 199 |
+
noise_scale=noise_scale,
|
| 200 |
+
noise_scale_w=noise_scale_w,
|
| 201 |
+
length_scale=length_scale,
|
| 202 |
+
sid=speaker,
|
| 203 |
+
language=language,
|
| 204 |
+
hps=hps,
|
| 205 |
+
net_g=net_g,
|
| 206 |
+
device=device,
|
| 207 |
+
skip_start=skip_start,
|
| 208 |
+
skip_end=skip_end,
|
| 209 |
+
)
|
| 210 |
+
audio_list_sent.append(audio)
|
| 211 |
+
                    silence = np.zeros((int)(44100 * interval_between_sent))
                    audio_list_sent.append(silence)
                if (interval_between_para - interval_between_sent) > 0:
                    silence = np.zeros(
                        (int)(44100 * (interval_between_para - interval_between_sent))
                    )
                    audio_list_sent.append(silence)
                audio16bit = gr.processing_utils.convert_to_16_bit_wav(
                    np.concatenate(audio_list_sent)
                )  # volume-normalize the complete sentence
                audio_list.append(audio16bit)
    audio_concat = np.concatenate(audio_list)
    return ("Success", (44100, audio_concat))


def tts_fn(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    language,
    reference_audio,
    emotion,
    prompt_mode,
):
    if prompt_mode == "Audio prompt":
        if reference_audio == None:
            return ("Invalid audio prompt", None)
        else:
            reference_audio = load_audio(reference_audio)[1]
    else:
        reference_audio = None
    audio_list = []
    if language == "mix":
        bool_valid, str_valid = re_matching.validate_text(text)
        if not bool_valid:
            return str_valid, (
                hps.data.sampling_rate,
                np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
            )
        result = []
        for slice in re_matching.text_matching(text):
            _speaker = slice.pop()
            temp_contant = []
            temp_lang = []
            for lang, content in slice:
                if "|" in content:
                    temp = []
                    temp_ = []
                    for i in content.split("|"):
                        if i != "":
                            temp.append([i])
                            temp_.append([lang])
                        else:
                            temp.append([])
                            temp_.append([])
                    temp_contant += temp
                    temp_lang += temp_
                else:
                    if len(temp_contant) == 0:
                        temp_contant.append([])
                        temp_lang.append([])
                    temp_contant[-1].append(content)
                    temp_lang[-1].append(lang)
            for i, j in zip(temp_lang, temp_contant):
                result.append([*zip(i, j), _speaker])
        for i, one in enumerate(result):
            skip_start = i != 0
            skip_end = i != len(result) - 1
            _speaker = one.pop()
            idx = 0
            while idx < len(one):
                text_to_generate = []
                lang_to_generate = []
                while True:
                    lang, content = one[idx]
                    temp_text = [content]
                    if len(text_to_generate) > 0:
                        text_to_generate[-1] += [temp_text.pop(0)]
                        lang_to_generate[-1] += [lang]
                    if len(temp_text) > 0:
                        text_to_generate += [[i] for i in temp_text]
                        lang_to_generate += [[lang]] * len(temp_text)
                    if idx + 1 < len(one):
                        idx += 1
                    else:
                        break
                skip_start = (idx != 0) and skip_start
                skip_end = (idx != len(one) - 1) and skip_end
                print(text_to_generate, lang_to_generate)
                audio_list.extend(
                    generate_audio_multilang(
                        text_to_generate,
                        sdp_ratio,
                        noise_scale,
                        noise_scale_w,
                        length_scale,
                        _speaker,
                        lang_to_generate,
                        reference_audio,
                        emotion,
                        skip_start,
                        skip_end,
                    )
                )
                idx += 1
    elif language.lower() == "auto":
        for idx, slice in enumerate(text.split("|")):
            if slice == "":
                continue
            skip_start = idx != 0
            skip_end = idx != len(text.split("|")) - 1
            sentences_list = split_by_language(
                slice, target_languages=["zh", "ja", "en"]
            )
            idx = 0
            while idx < len(sentences_list):
                text_to_generate = []
                lang_to_generate = []
                while True:
                    content, lang = sentences_list[idx]
                    temp_text = [content]
                    lang = lang.upper()
                    if lang == "JA":
                        lang = "JP"
                    if len(text_to_generate) > 0:
                        text_to_generate[-1] += [temp_text.pop(0)]
                        lang_to_generate[-1] += [lang]
                    if len(temp_text) > 0:
                        text_to_generate += [[i] for i in temp_text]
                        lang_to_generate += [[lang]] * len(temp_text)
                    if idx + 1 < len(sentences_list):
                        idx += 1
                    else:
                        break
                skip_start = (idx != 0) and skip_start
                skip_end = (idx != len(sentences_list) - 1) and skip_end
                print(text_to_generate, lang_to_generate)
                audio_list.extend(
                    generate_audio_multilang(
                        text_to_generate,
                        sdp_ratio,
                        noise_scale,
                        noise_scale_w,
                        length_scale,
                        speaker,
                        lang_to_generate,
                        reference_audio,
                        emotion,
                        skip_start,
                        skip_end,
                    )
                )
                idx += 1
    else:
        audio_list.extend(
            generate_audio(
                text.split("|"),
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                speaker,
                language,
                reference_audio,
                emotion,
            )
        )

    audio_concat = np.concatenate(audio_list)
    return "Success", (hps.data.sampling_rate, audio_concat)


def load_audio(path):
    audio, sr = librosa.load(path, 48000)
    # audio = librosa.resample(audio, 44100, 48000)
    return sr, audio


def gr_util(item):
    if item == "Text prompt":
        return {"visible": True, "__type__": "update"}, {
            "visible": False,
            "__type__": "update",
        }
    else:
        return {"visible": False, "__type__": "update"}, {
            "visible": True,
            "__type__": "update",
        }


if __name__ == "__main__":
    if config.webui_config.debug:
        logger.info("Enable DEBUG-LEVEL log")
        logging.basicConfig(level=logging.DEBUG)
    hps = utils.get_hparams_from_file(config.webui_config.config_path)
    # default to the latest version if config.json does not specify one
    version = hps.version if hasattr(hps, "version") else latest_version
    net_g = get_net_g(
        model_path=config.webui_config.model, version=version, device=device, hps=hps
    )
    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())
    languages = ["ZH", "JP", "EN", "mix", "auto"]
    with gr.Blocks() as app:
        with gr.Row():
            with gr.Column():
                text = gr.TextArea(
                    label="输入文本内容",
                    placeholder="""
    如果你选择语言为'mix',必须按照格式输入,否则报错:
    格式举例(zh是中文,jp是日语,不区分大小写;说话人举例:gongzi):
    [说话人1]<zh>你好,こんにちは! <jp>こんにちは,世界。
    [说话人2]<zh>你好吗?<jp>元気ですか?
    [说话人3]<zh>谢谢。<jp>どういたしまして。
    ...
    另外,所有的语言选项都可以用'|'分割长段实现分句生成。
                    """,
                )
                trans = gr.Button("中翻日", variant="primary")
                slicer = gr.Button("快速切分", variant="primary")
                speaker = gr.Dropdown(
                    choices=speakers, value=speakers[0], label="Speaker"
                )
                _ = gr.Markdown(
                    value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n"
                )
                prompt_mode = gr.Radio(
                    ["Text prompt", "Audio prompt"],
                    label="Prompt Mode",
                    value="Text prompt",
                )
                text_prompt = gr.Textbox(
                    label="Text prompt",
                    placeholder="用文字描述生成风格。如:Happy",
                    value="Happy",
                    visible=True,
                )
                audio_prompt = gr.Audio(
                    label="Audio prompt", type="filepath", visible=False
                )
                sdp_ratio = gr.Slider(
                    minimum=0, maximum=1, value=0.2, step=0.1, label="SDP Ratio"
                )
                noise_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
                )
                noise_scale_w = gr.Slider(
                    minimum=0.1, maximum=2, value=0.8, step=0.1, label="Noise_W"
                )
                length_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
                )
                language = gr.Dropdown(
                    choices=languages, value=languages[0], label="Language"
                )
                btn = gr.Button("生成音频!", variant="primary")
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        interval_between_sent = gr.Slider(
                            minimum=0,
                            maximum=5,
                            value=0.2,
                            step=0.1,
                            label="句间停顿(秒),勾选按句切分才生效",
                        )
                        interval_between_para = gr.Slider(
                            minimum=0,
                            maximum=10,
                            value=1,
                            step=0.1,
                            label="段间停顿(秒),需要大于句间停顿才有效",
                        )
                        opt_cut_by_sent = gr.Checkbox(
                            label="按句切分    在按段落切分的基础上再按句子切分文本"
                        )
                        slicer = gr.Button("切分生成", variant="primary")
                text_output = gr.Textbox(label="状态信息")
                audio_output = gr.Audio(label="输出音频")
                # explain_image = gr.Image(
                #     label="参数解释信息",
                #     show_label=True,
                #     show_share_button=False,
                #     show_download_button=False,
                #     value=os.path.abspath("./img/参数说明.png"),
                # )
        btn.click(
            tts_fn,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                language,
                audio_prompt,
                text_prompt,
                prompt_mode,
            ],
            outputs=[text_output, audio_output],
        )

        trans.click(
            translate,
            inputs=[text],
            outputs=[text],
        )
        slicer.click(
            tts_split,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                language,
                opt_cut_by_sent,
                interval_between_para,
                interval_between_sent,
                audio_prompt,
                text_prompt,
            ],
            outputs=[text_output, audio_output],
        )

        prompt_mode.change(
            lambda x: gr_util(x),
            inputs=[prompt_mode],
            outputs=[text_prompt, audio_prompt],
        )

        audio_prompt.upload(
            lambda x: load_audio(x),
            inputs=[audio_prompt],
            outputs=[audio_prompt],
        )

    print("推理页面已开启!")
    webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
    app.launch(share=config.webui_config.share, server_port=config.webui_config.port)
infer.py
CHANGED
@@ -10,7 +10,8 @@
 import torch
 import commons
 from text import cleaned_text_to_sequence, get_bert
-
+
+# from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
 from text.cleaner import clean_text
 import utils
 import numpy as np
@@ -20,47 +21,47 @@ from text.symbols import symbols
 
 # from oldVersion.V210.models import SynthesizerTrn as V210SynthesizerTrn
 # from oldVersion.V210.text import symbols as V210symbols
-from oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
-from oldVersion.V200.text import symbols as V200symbols
-from oldVersion.V111.models import SynthesizerTrn as V111SynthesizerTrn
-from oldVersion.V111.text import symbols as V111symbols
-from oldVersion.V110.models import SynthesizerTrn as V110SynthesizerTrn
-from oldVersion.V110.text import symbols as V110symbols
-from oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
-from oldVersion.V101.text import symbols as V101symbols
+# from oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
+# from oldVersion.V200.text import symbols as V200symbols
+# from oldVersion.V111.models import SynthesizerTrn as V111SynthesizerTrn
+# from oldVersion.V111.text import symbols as V111symbols
+# from oldVersion.V110.models import SynthesizerTrn as V110SynthesizerTrn
+# from oldVersion.V110.text import symbols as V110symbols
+# from oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
+# from oldVersion.V101.text import symbols as V101symbols
 
-from oldVersion import V111, V110, V101, V200
+# from oldVersion import V111, V110, V101, V200, V210
 
 # current version info
-latest_version = "2.
+latest_version = "2.3"
 
 # version compatibility
 SynthesizerTrnMap = {
     # "2.1": V210SynthesizerTrn,
-    "2.0.2-fix": V200SynthesizerTrn,
-    "2.0.1": V200SynthesizerTrn,
-    "2.0": V200SynthesizerTrn,
-    "1.1.1-fix": V111SynthesizerTrn,
-    "1.1.1": V111SynthesizerTrn,
-    "1.1": V110SynthesizerTrn,
-    "1.1.0": V110SynthesizerTrn,
-    "1.0.1": V101SynthesizerTrn,
-    "1.0": V101SynthesizerTrn,
-    "1.0.0": V101SynthesizerTrn,
+    # "2.0.2-fix": V200SynthesizerTrn,
+    # "2.0.1": V200SynthesizerTrn,
+    # "2.0": V200SynthesizerTrn,
+    # "1.1.1-fix": V111SynthesizerTrn,
+    # "1.1.1": V111SynthesizerTrn,
+    # "1.1": V110SynthesizerTrn,
+    # "1.1.0": V110SynthesizerTrn,
+    # "1.0.1": V101SynthesizerTrn,
+    # "1.0": V101SynthesizerTrn,
+    # "1.0.0": V101SynthesizerTrn,
 }
 
 symbolsMap = {
     # "2.1": V210symbols,
-    "2.0.2-fix": V200symbols,
-    "2.0.1": V200symbols,
-    "2.0": V200symbols,
-    "1.1.1-fix": V111symbols,
-    "1.1.1": V111symbols,
-    "1.1": V110symbols,
-    "1.1.0": V110symbols,
-    "1.0.1": V101symbols,
-    "1.0": V101symbols,
-    "1.0.0": V101symbols,
+    # "2.0.2-fix": V200symbols,
+    # "2.0.1": V200symbols,
+    # "2.0": V200symbols,
+    # "1.1.1-fix": V111symbols,
+    # "1.1.1": V111symbols,
+    # "1.1": V110symbols,
+    # "1.1.0": V110symbols,
+    # "1.0.1": V101symbols,
+    # "1.0": V101symbols,
+    # "1.0.0": V101symbols,
 }
@@ -98,7 +99,8 @@ def get_net_g(model_path: str, version: str, device: str, hps):
     return net_g
 
 
-def get_text(text, language_str, hps, device):
+def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
+    style_text = None if style_text == "" else style_text
     # implement the current version's get_text here
     norm_text, phone, tone, word2ph = clean_text(text, language_str)
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
@@ -110,21 +112,23 @@ def get_text(text, language_str, hps, device):
     for i in range(len(word2ph)):
         word2ph[i] = word2ph[i] * 2
     word2ph[0] += 1
-    bert_ori = get_bert(
+    bert_ori = get_bert(
+        norm_text, word2ph, language_str, device, style_text, style_weight
+    )
     del word2ph
     assert bert_ori.shape[-1] == len(phone), phone
 
     if language_str == "ZH":
         bert = bert_ori
-        ja_bert = torch.
-        en_bert = torch.
+        ja_bert = torch.randn(1024, len(phone))
+        en_bert = torch.randn(1024, len(phone))
     elif language_str == "JP":
-        bert = torch.
+        bert = torch.randn(1024, len(phone))
         ja_bert = bert_ori
-        en_bert = torch.
+        en_bert = torch.randn(1024, len(phone))
     elif language_str == "EN":
-        bert = torch.
-        ja_bert = torch.
+        bert = torch.randn(1024, len(phone))
+        ja_bert = torch.randn(1024, len(phone))
         en_bert = bert_ori
     else:
         raise ValueError("language_str should be ZH, JP or EN")
@@ -154,49 +158,54 @@ def infer(
     reference_audio=None,
     skip_start=False,
     skip_end=False,
+    style_text=None,
+    style_weight=0.7,
 ):
     # parameter positions changed in version 2.2
     # version 2.1 added the emotion, reference_audio, skip_start and skip_end parameters
     # inferMap_V3 = {
     #     "2.1": V210.infer,
-
+    }
     # versions that support Chinese, Japanese and English
     inferMap_V2 = {
-        "2.0.2-fix": V200.infer,
-        "2.0.1": V200.infer,
-        "2.0": V200.infer,
-        "1.1.1-fix": V111.infer_fix,
-        "1.1.1": V111.infer,
-        "1.1": V110.infer,
-        "1.1.0": V110.infer,
+        # "2.0.2-fix": V200.infer,
+        # "2.0.1": V200.infer,
+        # "2.0": V200.infer,
+        # "1.1.1-fix": V111.infer_fix,
+        # "1.1.1": V111.infer,
+        # "1.1": V110.infer,
+        # "1.1.0": V110.infer,
    }
    # Chinese-only versions
    # in testing, models of these two versions were found to be interchangeable
    inferMap_V1 = {
-        "1.0.1": V101.infer,
-        "1.0": V101.infer,
-        "1.0.0": V101.infer,
+        # "1.0.1": V101.infer,
+        # "1.0": V101.infer,
+        # "1.0.0": V101.infer,
    }
    version = hps.version if hasattr(hps, "version") else latest_version
    # not the current version: pick the matching infer by version number
    if version != latest_version:
+        if version in inferMap_V3.keys():
+            emotion = 0
+            return inferMap_V3[version](
+                text,
+                sdp_ratio,
+                noise_scale,
+                noise_scale_w,
+                length_scale,
+                sid,
+                language,
+                hps,
+                net_g,
+                device,
+                reference_audio,
+                emotion,
+                skip_start,
+                skip_end,
+                style_text,
+                style_weight,
+            )
        if version in inferMap_V2.keys():
            return inferMap_V2[version](
                text,
@@ -224,14 +233,19 @@ def infer(
     )
     # implement the current version's inference here
     # emo = get_emo_(reference_audio, emotion, sid)
-    if isinstance(reference_audio, np.ndarray):
-    else:
-    emo = torch.squeeze(emo, dim=1)
+    # if isinstance(reference_audio, np.ndarray):
+    #     emo = get_clap_audio_feature(reference_audio, device)
+    # else:
+    #     emo = get_clap_text_feature(emotion, device)
+    # emo = torch.squeeze(emo, dim=1)
 
     bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
-        text,
+        text,
+        language,
+        hps,
+        device,
+        style_text=style_text,
+        style_weight=style_weight,
     )
     if skip_start:
         phones = phones[3:]
@@ -255,7 +269,7 @@ def infer(
         ja_bert = ja_bert.to(device).unsqueeze(0)
         en_bert = en_bert.to(device).unsqueeze(0)
         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
-        emo = emo.to(device).unsqueeze(0)
+        # emo = emo.to(device).unsqueeze(0)
         del phones
         speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
         audio = (
@@ -268,7 +282,6 @@ def infer(
                 bert,
                 ja_bert,
                 en_bert,
-                emo,
                 sdp_ratio=sdp_ratio,
                 noise_scale=noise_scale,
                 noise_scale_w=noise_scale_w,
@@ -278,7 +291,16 @@ def infer(
             .float()
             .numpy()
         )
-        del
+        del (
+            x_tst,
+            tones,
+            lang_ids,
+            bert,
+            x_tst_lengths,
+            speakers,
+            ja_bert,
+            en_bert,
+        )  # , emo
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         return audio
@@ -302,14 +324,14 @@ def infer_multilang(
 ):
     bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
     # emo = get_emo_(reference_audio, emotion, sid)
-    if isinstance(reference_audio, np.ndarray):
-    else:
-    emo = torch.squeeze(emo, dim=1)
+    # if isinstance(reference_audio, np.ndarray):
+    #     emo = get_clap_audio_feature(reference_audio, device)
+    # else:
+    #     emo = get_clap_text_feature(emotion, device)
+    # emo = torch.squeeze(emo, dim=1)
     for idx, (txt, lang) in enumerate(zip(text, language)):
-
-
+        _skip_start = (idx != 0) or (skip_start and idx == 0)
+        _skip_end = (idx != len(language) - 1) or skip_end
         (
             temp_bert,
             temp_ja_bert,
@@ -318,14 +340,14 @@ def infer_multilang(
             temp_tones,
             temp_lang_ids,
         ) = get_text(txt, lang, hps, device)
-        if
+        if _skip_start:
             temp_bert = temp_bert[:, 3:]
             temp_ja_bert = temp_ja_bert[:, 3:]
             temp_en_bert = temp_en_bert[:, 3:]
             temp_phones = temp_phones[3:]
             temp_tones = temp_tones[3:]
             temp_lang_ids = temp_lang_ids[3:]
-        if
+        if _skip_end:
            temp_bert = temp_bert[:, :-2]
            temp_ja_bert = temp_ja_bert[:, :-2]
            temp_en_bert = temp_en_bert[:, :-2]
@@ -351,7 +373,7 @@ def infer_multilang(
     bert = bert.to(device).unsqueeze(0)
     ja_bert = ja_bert.to(device).unsqueeze(0)
     en_bert = en_bert.to(device).unsqueeze(0)
-    emo = emo.to(device).unsqueeze(0)
+    # emo = emo.to(device).unsqueeze(0)
     x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
     del phones
     speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
@@ -365,7 +387,6 @@ def infer_multilang(
             bert,
             ja_bert,
             en_bert,
-            emo,
             sdp_ratio=sdp_ratio,
             noise_scale=noise_scale,
             noise_scale_w=noise_scale_w,
@@ -375,7 +396,16 @@ def infer_multilang(
             .float()
             .numpy()
         )
-        del
+        del (
+            x_tst,
+            tones,
+            lang_ids,
+            bert,
+            x_tst_lengths,
+            speakers,
+            ja_bert,
+            en_bert,
+        )  # , emo
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return audio
losses.py
CHANGED
@@ -1,4 +1,6 @@
 import torch
+import torchaudio
+from transformers import AutoModel
 
 
 def feature_loss(fmap_r, fmap_g):
@@ -56,3 +58,96 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
     kl = torch.sum(kl * z_mask)
     l = kl / torch.sum(z_mask)
     return l
+
+
+class WavLMLoss(torch.nn.Module):
+    def __init__(self, model, wd, model_sr, slm_sr=16000):
+        super(WavLMLoss, self).__init__()
+        self.wavlm = AutoModel.from_pretrained(model)
+        self.wd = wd
+        self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
+        self.wavlm.eval()
+        for param in self.wavlm.parameters():
+            param.requires_grad = False
+
+    def forward(self, wav, y_rec):
+        with torch.no_grad():
+            wav_16 = self.resample(wav)
+            wav_embeddings = self.wavlm(
+                input_values=wav_16, output_hidden_states=True
+            ).hidden_states
+        y_rec_16 = self.resample(y_rec)
+        y_rec_embeddings = self.wavlm(
+            input_values=y_rec_16.squeeze(), output_hidden_states=True
+        ).hidden_states
+
+        floss = 0
+        for er, eg in zip(wav_embeddings, y_rec_embeddings):
+            floss += torch.mean(torch.abs(er - eg))
+
+        return floss.mean()
+
+    def generator(self, y_rec):
+        y_rec_16 = self.resample(y_rec)
+        y_rec_embeddings = self.wavlm(
+            input_values=y_rec_16, output_hidden_states=True
+        ).hidden_states
+        y_rec_embeddings = (
+            torch.stack(y_rec_embeddings, dim=1)
+            .transpose(-1, -2)
+            .flatten(start_dim=1, end_dim=2)
+        )
+        y_df_hat_g = self.wd(y_rec_embeddings)
+        loss_gen = torch.mean((1 - y_df_hat_g) ** 2)
+
+        return loss_gen
+
+    def discriminator(self, wav, y_rec):
+        with torch.no_grad():
+            wav_16 = self.resample(wav)
+            wav_embeddings = self.wavlm(
+                input_values=wav_16, output_hidden_states=True
+            ).hidden_states
+            y_rec_16 = self.resample(y_rec)
+            y_rec_embeddings = self.wavlm(
+                input_values=y_rec_16, output_hidden_states=True
+            ).hidden_states
+
+            y_embeddings = (
+                torch.stack(wav_embeddings, dim=1)
+                .transpose(-1, -2)
+                .flatten(start_dim=1, end_dim=2)
+            )
+            y_rec_embeddings = (
+                torch.stack(y_rec_embeddings, dim=1)
+                .transpose(-1, -2)
+                .flatten(start_dim=1, end_dim=2)
+            )
+
+        y_d_rs = self.wd(y_embeddings)
+        y_d_gs = self.wd(y_rec_embeddings)
+
+        y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
+
+        r_loss = torch.mean((1 - y_df_hat_r) ** 2)
+        g_loss = torch.mean((y_df_hat_g) ** 2)
+
+        loss_disc_f = r_loss + g_loss
+
+        return loss_disc_f.mean()
+
+    def discriminator_forward(self, wav):
+        with torch.no_grad():
+            wav_16 = self.resample(wav)
+            wav_embeddings = self.wavlm(
+                input_values=wav_16, output_hidden_states=True
+            ).hidden_states
+            y_embeddings = (
+                torch.stack(wav_embeddings, dim=1)
+                .transpose(-1, -2)
+                .flatten(start_dim=1, end_dim=2)
+            )
+
+        y_d_rs = self.wd(y_embeddings)
+
+        return y_d_rs
models.py
CHANGED
@@ -40,33 +40,22 @@ class DurationDiscriminator(nn.Module):  # vits2
         self.norm_2 = modules.LayerNorm(filter_channels)
         self.dur_proj = nn.Conv1d(1, filter_channels, 1)
 
-        self.pre_out_conv_1 = nn.Conv1d(
-            2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        self.LSTM = nn.LSTM(
+            2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
         )
-        self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
-        self.pre_out_conv_2 = nn.Conv1d(
-            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
-        )
-        self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
 
         if gin_channels != 0:
             self.cond = nn.Conv1d(gin_channels, in_channels, 1)
 
-        self.output_layer = nn.Sequential(
+        self.output_layer = nn.Sequential(
+            nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
+        )
 
-    def forward_probability(self, x,
+    def forward_probability(self, x, dur):
         dur = self.dur_proj(dur)
         x = torch.cat([x, dur], dim=1)
-        x = self.pre_out_conv_1(x * x_mask)
-        x = torch.relu(x)
-        x = self.pre_out_norm_1(x)
-        x = self.drop(x)
-        x = self.pre_out_conv_2(x * x_mask)
-        x = torch.relu(x)
-        x = self.pre_out_norm_2(x)
-        x = self.drop(x)
-        x = x * x_mask
         x = x.transpose(1, 2)
+        x, _ = self.LSTM(x)
         output_prob = self.output_layer(x)
         return output_prob
@@ -86,7 +75,7 @@ class DurationDiscriminator(nn.Module):  # vits2
 
         output_probs = []
         for dur in [dur_r, dur_hat]:
-            output_prob = self.forward_probability(x,
+            output_prob = self.forward_probability(x, dur)
             output_probs.append(output_prob)
 
         return output_probs
@@ -354,7 +343,6 @@ class TextEncoder(nn.Module):
         n_layers,
         kernel_size,
         p_dropout,
-        n_speakers,
         gin_channels=0,
     ):
         super().__init__()
@@ -376,31 +364,6 @@ class TextEncoder(nn.Module):
         self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
         self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
         self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
-        # self.emo_proj = nn.Linear(512, hidden_channels)
-        self.in_feature_net = nn.Sequential(
-            # input is assumed to an already normalized embedding
-            nn.Linear(512, 1028, bias=False),
-            nn.GELU(),
-            nn.LayerNorm(1028),
-            *[Block(1028, 512) for _ in range(1)],
-            nn.Linear(1028, 512, bias=False),
-            # normalize before passing to VQ?
-            # nn.GELU(),
-            # nn.LayerNorm(512),
-        )
-        self.emo_vq = VectorQuantize(
-            dim=512,
-            codebook_size=64,
-            codebook_dim=32,
-            commitment_weight=0.1,
-            decay=0.85,
-            heads=32,
-            kmeans_iters=20,
-            separate_codebook_per_head=True,
-            stochastic_sample_codes=True,
-            threshold_ema_dead_code=2,
-        )
-        self.out_feature_net = nn.Linear(512, hidden_channels)
 
         self.encoder = attentions.Encoder(
             hidden_channels,
@@ -413,18 +376,10 @@ class TextEncoder(nn.Module):
         )
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
-    def forward(
-        self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=None
-    ):
-        sid = sid.cpu()
+    def forward(self, x, x_lengths, tone, language, bert, ja_bert, en_bert, g=None):
         bert_emb = self.bert_proj(bert).transpose(1, 2)
         ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
         en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
-        emo_emb = self.in_feature_net(emo)
-        emo_emb, _, loss_commit = self.emo_vq(emo_emb.unsqueeze(1))
-        loss_commit = loss_commit.mean()
-        emo_emb = self.out_feature_net(emo_emb)
-        # emo_emb = self.emo_proj(emo.unsqueeze(1))
         x = (
             self.emb(x)
             + self.tone_emb(tone)
@@ -432,7 +387,6 @@ class TextEncoder(nn.Module):
             + bert_emb
             + ja_bert_emb
             + en_bert_emb
-            + emo_emb
         ) * math.sqrt(
             self.hidden_channels
         )  # [b, t, h]
@@ -445,7 +399,7 @@ class TextEncoder(nn.Module):
         stats = self.proj(x) * x_mask
 
         m, logs = torch.split(stats, self.out_channels, dim=1)
-        return x, m, logs, x_mask
+        return x, m, logs, x_mask
 
 
 class ResidualCouplingBlock(nn.Module):
@@ -748,6 +702,55 @@ class MultiPeriodDiscriminator(torch.nn.Module):
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
 
 
+class WavLMDiscriminator(nn.Module):
+    """docstring for Discriminator."""
+
+    def __init__(
+        self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
+    ):
+        super(WavLMDiscriminator, self).__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.pre = norm_f(
+            Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
+        )
+
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    nn.Conv1d(
+                        initial_channel, initial_channel * 2, kernel_size=5, padding=2
+                    )
+                ),
+                norm_f(
+                    nn.Conv1d(
+                        initial_channel * 2,
+                        initial_channel * 4,
+                        kernel_size=5,
+                        padding=2,
+                    )
+                ),
+                norm_f(
+                    nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
+                ),
+            ]
+        )
+
+        self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        x = self.pre(x)
+
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, modules.LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x
+
+
 class ReferenceEncoder(nn.Module):
     """
     inputs --- [N, Ty/r, n_mels*r]  mels
@@ -878,7 +881,6 @@ class SynthesizerTrn(nn.Module):
             n_layers,
             kernel_size,
             p_dropout,
-            self.n_speakers,
             gin_channels=self.enc_gin_channels,
         )
         self.dec = Generator(
@@ -946,14 +948,13 @@ class SynthesizerTrn(nn.Module):
         bert,
         ja_bert,
         en_bert,
-        emo=None,
     ):
         if self.n_speakers > 0:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
         else:
             g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
-        x, m_p, logs_p, x_mask
-        x, x_lengths, tone, language, bert, ja_bert, en_bert,
+        x, m_p, logs_p, x_mask = self.enc_p(
+            x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
+        )
         z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
         z_p = self.flow(z, y_mask, g=g)
@@ -996,9 +997,11 @@ class SynthesizerTrn(nn.Module):
 
         logw_ = torch.log(w + 1e-6) * x_mask
         logw = self.dp(x, x_mask, g=g)
+        logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
         l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
             x_mask
         )  # for averaging
+        l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
 
         l_length = l_length_dp + l_length_sdp
@@ -1018,9 +1021,8 @@ class SynthesizerTrn(nn.Module):
             x_mask,
             y_mask,
             (z, z_p, m_p, logs_p, m_q, logs_q),
-            (x, logw, logw_),
+            (x, logw, logw_, logw_sdp),
             g,
-            loss_commit,
         )
 
     def infer(
@@ -1033,7 +1035,6 @@ class SynthesizerTrn(nn.Module):
         bert,
         ja_bert,
        en_bert,
-        emo=None,
        noise_scale=0.667,
        length_scale=1,
        noise_scale_w=0.8,
@@ -1047,8 +1048,8 @@ class SynthesizerTrn(nn.Module):
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
-        x, m_p, logs_p, x_mask
-        x, x_lengths, tone, language, bert, ja_bert, en_bert,
+        x, m_p, logs_p, x_mask = self.enc_p(
+            x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
+        )
        logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
            sdp_ratio
oldVersion/V210/__init__.py
CHANGED
@@ -5,10 +5,9 @@ import torch
 import commons
 from .text import cleaned_text_to_sequence, get_bert
 from .text.cleaner import clean_text
-from .emo_gen import get_emo
 
 
-def get_text(text, language_str, hps, device):
+def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
     # implement this version's get_text here
     norm_text, phone, tone, word2ph = clean_text(text, language_str)
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
@@ -20,7 +19,9 @@ def get_text(text, language_str, hps, device):
     for i in range(len(word2ph)):
         word2ph[i] = word2ph[i] * 2
     word2ph[0] += 1
-    bert_ori = get_bert(
+    bert_ori = get_bert(
+        norm_text, word2ph, language_str, device, style_text, style_weight
+    )
     del word2ph
     assert bert_ori.shape[-1] == len(phone), phone
@@ -50,6 +51,8 @@
 
 
 def get_emo_(reference_audio, emotion):
+    from .emo_gen import get_emo
+
     emo = (
         torch.from_numpy(get_emo(reference_audio))
         if reference_audio
@@ -73,9 +76,11 @@ def infer(
     emotion=None,
     skip_start=False,
     skip_end=False,
+    style_text=None,
+    style_weight=0.7,
 ):
     bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
-        text, language, hps, device
+        text, language, hps, device, style_text, style_weight
     )
     emo = get_emo_(reference_audio, emotion)
     if skip_start:
oldVersion/V210/models.py
CHANGED
@@ -13,7 +13,7 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
 from vector_quantize_pytorch import VectorQuantize
 
 from commons import init_weights, get_padding
-from text import symbols, num_tones, num_languages
+from .text import symbols, num_tones, num_languages
 
 
 class DurationDiscriminator(nn.Module):  # vits2
oldVersion/V210/text/__init__.py
CHANGED
@@ -18,13 +18,15 @@ def cleaned_text_to_sequence(cleaned_text, tones, language):
     return phones, tones, lang_ids
 
 
-def get_bert(norm_text, word2ph, language, device):
+def get_bert(norm_text, word2ph, language, device, style_text, style_weight):
     from .chinese_bert import get_bert_feature as zh_bert
     from .english_bert_mock import get_bert_feature as en_bert
     from .japanese_bert import get_bert_feature as jp_bert
 
     lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
-    bert = lang_bert_func_map[language](
+    bert = lang_bert_func_map[language](
+        norm_text, word2ph, device, style_text, style_weight
+    )
     return bert
oldVersion/V210/text/chinese_bert.py
CHANGED
@@ -12,7 +12,13 @@ tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
 models = dict()
 
 
-def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
+def get_bert_feature(
+    text,
+    word2ph,
+    device=config.bert_gen_config.device,
+    style_text=None,
+    style_weight=0.7,
+):
     if (
         sys.platform == "darwin"
         and torch.backends.mps.is_available()
@@ -29,12 +35,25 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
             inputs[i] = inputs[i].to(device)
         res = models[device](**inputs, output_hidden_states=True)
         res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
+        if style_text:
+            style_inputs = tokenizer(style_text, return_tensors="pt")
+            for i in style_inputs:
+                style_inputs[i] = style_inputs[i].to(device)
+            style_res = models[device](**style_inputs, output_hidden_states=True)
+            style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
+            style_res_mean = style_res.mean(0)
 
     assert len(word2ph) == len(text) + 2
     word2phone = word2ph
     phone_level_feature = []
     for i in range(len(word2phone)):
-        repeat_feature = res[i].repeat(word2phone[i], 1)
+        if style_text:
+            repeat_feature = (
+                res[i].repeat(word2phone[i], 1) * (1 - style_weight)
+                + style_res_mean.repeat(word2phone[i], 1) * style_weight
+            )
+        else:
+            repeat_feature = res[i].repeat(word2phone[i], 1)
         phone_level_feature.append(repeat_feature)
 
     phone_level_feature = torch.cat(phone_level_feature, dim=0)
oldVersion/V210/text/english_bert_mock.py
CHANGED
@@ -13,7 +13,13 @@ tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
 models = dict()
 
 
-def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
+def get_bert_feature(
+    text,
+    word2ph,
+    device=config.bert_gen_config.device,
+    style_text=None,
+    style_weight=0.7,
+):
     if (
         sys.platform == "darwin"
         and torch.backends.mps.is_available()
@@ -30,11 +36,24 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
             inputs[i] = inputs[i].to(device)
         res = models[device](**inputs, output_hidden_states=True)
         res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
+        if style_text:
+            style_inputs = tokenizer(style_text, return_tensors="pt")
+            for i in style_inputs:
+                style_inputs[i] = style_inputs[i].to(device)
+            style_res = models[device](**style_inputs, output_hidden_states=True)
+            style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
+            style_res_mean = style_res.mean(0)
     assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph))
     word2phone = word2ph
     phone_level_feature = []
     for i in range(len(word2phone)):
-        repeat_feature = res[i].repeat(word2phone[i], 1)
+        if style_text:
+            repeat_feature = (
+                res[i].repeat(word2phone[i], 1) * (1 - style_weight)
+                + style_res_mean.repeat(word2phone[i], 1) * style_weight
+            )
+        else:
+            repeat_feature = res[i].repeat(word2phone[i], 1)
         phone_level_feature.append(repeat_feature)
 
     phone_level_feature = torch.cat(phone_level_feature, dim=0)
oldVersion/V210/text/japanese_bert.py
CHANGED
@@ -13,8 +13,16 @@ tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
 models = dict()
 
 
-def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
+def get_bert_feature(
+    text,
+    word2ph,
+    device=config.bert_gen_config.device,
+    style_text=None,
+    style_weight=0.7,
+):
     text = "".join(text2sep_kata(text)[0])
+    if style_text:
+        style_text = "".join(text2sep_kata(style_text)[0])
     if (
         sys.platform == "darwin"
         and torch.backends.mps.is_available()
@@ -31,12 +39,25 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
             inputs[i] = inputs[i].to(device)
         res = models[device](**inputs, output_hidden_states=True)
         res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
+        if style_text:
+            style_inputs = tokenizer(style_text, return_tensors="pt")
+            for i in style_inputs:
+                style_inputs[i] = style_inputs[i].to(device)
+            style_res = models[device](**style_inputs, output_hidden_states=True)
+            style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
+            style_res_mean = style_res.mean(0)
 
     assert len(word2ph) == len(text) + 2
     word2phone = word2ph
     phone_level_feature = []
     for i in range(len(word2phone)):
-        repeat_feature = res[i].repeat(word2phone[i], 1)
+        if style_text:
+            repeat_feature = (
+                res[i].repeat(word2phone[i], 1) * (1 - style_weight)
+                + style_res_mean.repeat(word2phone[i], 1) * style_weight
+            )
+        else:
+            repeat_feature = res[i].repeat(word2phone[i], 1)
         phone_level_feature.append(repeat_feature)
 
     phone_level_feature = torch.cat(phone_level_feature, dim=0)
onnx_infer.py
ADDED
@@ -0,0 +1,68 @@
from onnx_modules.V220_OnnxInference import OnnxInferenceSession
import numpy as np

Session = OnnxInferenceSession(
    {
        "enc": "onnx/BertVits2.2PT/BertVits2.2PT_enc_p.onnx",
        "emb_g": "onnx/BertVits2.2PT/BertVits2.2PT_emb.onnx",
        "dp": "onnx/BertVits2.2PT/BertVits2.2PT_dp.onnx",
        "sdp": "onnx/BertVits2.2PT/BertVits2.2PT_sdp.onnx",
        "flow": "onnx/BertVits2.2PT/BertVits2.2PT_flow.onnx",
        "dec": "onnx/BertVits2.2PT/BertVits2.2PT_dec.onnx",
    },
    Providers=["CPUExecutionProvider"],
)

# The inputs are the same as in the original version; just call .numpy() on the
# original preprocessing outputs before passing them in.
x = np.array(
    [
        0, 97, 0, 8, 0, 78, 0, 8, 0, 76, 0, 37, 0, 40, 0, 97, 0, 8,
        0, 23, 0, 8, 0, 74, 0, 26, 0, 104, 0,
    ]
)
tone = np.zeros_like(x)
language = np.zeros_like(x)
sid = np.array([0])
bert = np.random.randn(x.shape[0], 1024)
ja_bert = np.random.randn(x.shape[0], 1024)
en_bert = np.random.randn(x.shape[0], 1024)
emo = np.random.randn(512, 1)

audio = Session(
    x,
    tone,
    language,
    bert,
    ja_bert,
    en_bert,
    emo,
    sid,
)

print(audio)
onnx_modules/V200/__init__.py
CHANGED
@@ -0,0 +1,4 @@
+from .text.symbols import symbols
+from .models_onnx import SynthesizerTrn
+
+__all__ = ["symbols", "SynthesizerTrn"]
onnx_modules/V200_OnnxInference/__init__.py
ADDED
|
@@ -0,0 +1,126 @@
import numpy as np
import onnxruntime as ort


def convert_pad_shape(pad_shape):
    layer = pad_shape[::-1]
    pad_shape = [item for sublist in layer for item in sublist]
    return pad_shape


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = np.arange(max_length, dtype=length.dtype)
    return np.expand_dims(x, 0) < np.expand_dims(length, 1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """

    b, _, t_y, t_x = mask.shape
    cum_duration = np.cumsum(duration, -1)

    cum_duration_flat = cum_duration.reshape(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y)
    path = path.reshape(b, t_x, t_y)
    path = path ^ np.pad(path, ((0, 0), (1, 0), (0, 0)))[:, :-1]
    path = np.expand_dims(path, 1).transpose(0, 1, 3, 2)
    return path


class OnnxInferenceSession:
    def __init__(self, path, Providers=["CPUExecutionProvider"]):
        self.enc = ort.InferenceSession(path["enc"], providers=Providers)
        self.emb_g = ort.InferenceSession(path["emb_g"], providers=Providers)
        self.dp = ort.InferenceSession(path["dp"], providers=Providers)
        self.sdp = ort.InferenceSession(path["sdp"], providers=Providers)
        self.flow = ort.InferenceSession(path["flow"], providers=Providers)
        self.dec = ort.InferenceSession(path["dec"], providers=Providers)

    def __call__(
        self,
        seq,
        tone,
        language,
        bert_zh,
        bert_jp,
        bert_en,
        sid,
        seed=114514,
        seq_noise_scale=0.8,
        sdp_noise_scale=0.6,
        length_scale=1.0,
        sdp_ratio=0.0,
    ):
        if seq.ndim == 1:
            seq = np.expand_dims(seq, 0)
        if tone.ndim == 1:
            tone = np.expand_dims(tone, 0)
        if language.ndim == 1:
            language = np.expand_dims(language, 0)
        # assert each condition separately (the original asserted a tuple, which is always true)
        assert seq.ndim == 2 and tone.ndim == 2 and language.ndim == 2
        g = self.emb_g.run(
            None,
            {
                "sid": sid.astype(np.int64),
            },
        )[0]
        g = np.expand_dims(g, -1)
        enc_rtn = self.enc.run(
            None,
            {
                "x": seq.astype(np.int64),
                "t": tone.astype(np.int64),
                "language": language.astype(np.int64),
                "bert_0": bert_zh.astype(np.float32),
                "bert_1": bert_jp.astype(np.float32),
                "bert_2": bert_en.astype(np.float32),
                "g": g.astype(np.float32),
            },
        )
        x, m_p, logs_p, x_mask = enc_rtn[0], enc_rtn[1], enc_rtn[2], enc_rtn[3]
        np.random.seed(seed)
        zinput = np.random.randn(x.shape[0], 2, x.shape[2]) * sdp_noise_scale
        logw = self.sdp.run(
            None, {"x": x, "x_mask": x_mask, "zin": zinput.astype(np.float32), "g": g}
        )[0] * (sdp_ratio) + self.dp.run(None, {"x": x, "x_mask": x_mask, "g": g})[
            0
        ] * (
            1 - sdp_ratio
        )
        w = np.exp(logw) * x_mask * length_scale
        w_ceil = np.ceil(w)
        y_lengths = np.clip(np.sum(w_ceil, (1, 2)), a_min=1.0, a_max=100000).astype(
            np.int64
        )
        y_mask = np.expand_dims(sequence_mask(y_lengths, None), 1)
        attn_mask = np.expand_dims(x_mask, 2) * np.expand_dims(y_mask, -1)
        attn = generate_path(w_ceil, attn_mask)
        m_p = np.matmul(attn.squeeze(1), m_p.transpose(0, 2, 1)).transpose(
            0, 2, 1
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = np.matmul(attn.squeeze(1), logs_p.transpose(0, 2, 1)).transpose(
            0, 2, 1
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = (
            m_p
            + np.random.randn(m_p.shape[0], m_p.shape[1], m_p.shape[2])
            * np.exp(logs_p)
            * seq_noise_scale
        )

        z = self.flow.run(
            None,
            {
                "z_p": z_p.astype(np.float32),
                "y_mask": y_mask.astype(np.float32),
                "g": g,
            },
        )[0]

        return self.dec.run(None, {"z_in": z.astype(np.float32), "g": g})[0]
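A standalone sketch (illustrative only, numpy only, not part of the upload) of what sequence_mask and generate_path above compute: a per-phoneme duration vector is expanded into a monotonic phoneme-to-frame alignment path.

import numpy as np

def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = np.arange(max_length, dtype=length.dtype)
    return np.expand_dims(x, 0) < np.expand_dims(length, 1)

duration = np.array([[[2.0, 1.0, 3.0]]])   # [b, 1, t_x]: 3 phonemes lasting 2, 1, 3 frames
b, _, t_x = duration.shape
t_y = int(duration.sum())                  # 6 output frames in total
cum = np.cumsum(duration, -1).reshape(b * t_x)
path = sequence_mask(cum, t_y).reshape(b, t_x, t_y)
path = path ^ np.pad(path, ((0, 0), (1, 0), (0, 0)))[:, :-1]
print(path.astype(int)[0])
# [[1 1 0 0 0 0]
#  [0 0 1 0 0 0]
#  [0 0 0 1 1 1]]   each phoneme row is 1 exactly over the frames it owns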
onnx_modules/V210/__init__.py
CHANGED
|
@@ -0,0 +1,4 @@
from .text.symbols import symbols
from .models_onnx import SynthesizerTrn

__all__ = ["symbols", "SynthesizerTrn"]
onnx_modules/V210/models_onnx.py
CHANGED
|
@@ -942,7 +942,7 @@ class SynthesizerTrn(nn.Module):
 
         torch.onnx.export(
             self.enc_p,
-            (x, x_lengths, tone, language, bert, ja_bert, en_bert, g, sid
+            (x, x_lengths, tone, language, bert, ja_bert, en_bert, g, sid, sid),
             f"onnx/{path}/{path}_enc_p.onnx",
             input_names=[
                 "x",
onnx_modules/V210_OnnxInference/__init__.py
ADDED
|
@@ -0,0 +1,129 @@
| 1 |
+
import numpy as np
|
| 2 |
+
import onnxruntime as ort
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def convert_pad_shape(pad_shape):
|
| 6 |
+
layer = pad_shape[::-1]
|
| 7 |
+
pad_shape = [item for sublist in layer for item in sublist]
|
| 8 |
+
return pad_shape
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def sequence_mask(length, max_length=None):
|
| 12 |
+
if max_length is None:
|
| 13 |
+
max_length = length.max()
|
| 14 |
+
x = np.arange(max_length, dtype=length.dtype)
|
| 15 |
+
return np.expand_dims(x, 0) < np.expand_dims(length, 1)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def generate_path(duration, mask):
|
| 19 |
+
"""
|
| 20 |
+
duration: [b, 1, t_x]
|
| 21 |
+
mask: [b, 1, t_y, t_x]
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
b, _, t_y, t_x = mask.shape
|
| 25 |
+
cum_duration = np.cumsum(duration, -1)
|
| 26 |
+
|
| 27 |
+
cum_duration_flat = cum_duration.reshape(b * t_x)
|
| 28 |
+
path = sequence_mask(cum_duration_flat, t_y)
|
| 29 |
+
path = path.reshape(b, t_x, t_y)
|
| 30 |
+
path = path ^ np.pad(path, ((0, 0), (1, 0), (0, 0)))[:, :-1]
|
| 31 |
+
path = np.expand_dims(path, 1).transpose(0, 1, 3, 2)
|
| 32 |
+
return path
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class OnnxInferenceSession:
|
| 36 |
+
def __init__(self, path, Providers=["CPUExecutionProvider"]):
|
| 37 |
+
self.enc = ort.InferenceSession(path["enc"], providers=Providers)
|
| 38 |
+
self.emb_g = ort.InferenceSession(path["emb_g"], providers=Providers)
|
| 39 |
+
self.dp = ort.InferenceSession(path["dp"], providers=Providers)
|
| 40 |
+
self.sdp = ort.InferenceSession(path["sdp"], providers=Providers)
|
| 41 |
+
self.flow = ort.InferenceSession(path["flow"], providers=Providers)
|
| 42 |
+
self.dec = ort.InferenceSession(path["dec"], providers=Providers)
|
| 43 |
+
|
| 44 |
+
def __call__(
|
| 45 |
+
self,
|
| 46 |
+
seq,
|
| 47 |
+
tone,
|
| 48 |
+
language,
|
| 49 |
+
bert_zh,
|
| 50 |
+
bert_jp,
|
| 51 |
+
bert_en,
|
| 52 |
+
vqidx,
|
| 53 |
+
sid,
|
| 54 |
+
seed=114514,
|
| 55 |
+
seq_noise_scale=0.8,
|
| 56 |
+
sdp_noise_scale=0.6,
|
| 57 |
+
length_scale=1.0,
|
| 58 |
+
sdp_ratio=0.0,
|
| 59 |
+
):
|
| 60 |
+
if seq.ndim == 1:
|
| 61 |
+
seq = np.expand_dims(seq, 0)
|
| 62 |
+
if tone.ndim == 1:
|
| 63 |
+
tone = np.expand_dims(tone, 0)
|
| 64 |
+
if language.ndim == 1:
|
| 65 |
+
language = np.expand_dims(language, 0)
|
| 66 |
+
assert seq.ndim == 2 and tone.ndim == 2 and language.ndim == 2  # check each condition; asserting a tuple is always true
|
| 67 |
+
g = self.emb_g.run(
|
| 68 |
+
None,
|
| 69 |
+
{
|
| 70 |
+
"sid": sid.astype(np.int64),
|
| 71 |
+
},
|
| 72 |
+
)[0]
|
| 73 |
+
g = np.expand_dims(g, -1)
|
| 74 |
+
enc_rtn = self.enc.run(
|
| 75 |
+
None,
|
| 76 |
+
{
|
| 77 |
+
"x": seq.astype(np.int64),
|
| 78 |
+
"t": tone.astype(np.int64),
|
| 79 |
+
"language": language.astype(np.int64),
|
| 80 |
+
"bert_0": bert_zh.astype(np.float32),
|
| 81 |
+
"bert_1": bert_jp.astype(np.float32),
|
| 82 |
+
"bert_2": bert_en.astype(np.float32),
|
| 83 |
+
"g": g.astype(np.float32),
|
| 84 |
+
"vqidx": vqidx.astype(np.int64),
|
| 85 |
+
"sid": sid.astype(np.int64)
|
| 86 |
+
},
|
| 87 |
+
)
|
| 88 |
+
x, m_p, logs_p, x_mask = enc_rtn[0], enc_rtn[1], enc_rtn[2], enc_rtn[3]
|
| 89 |
+
np.random.seed(seed)
|
| 90 |
+
zinput = np.random.randn(x.shape[0], 2, x.shape[2]) * sdp_noise_scale
|
| 91 |
+
logw = self.sdp.run(
|
| 92 |
+
None, {"x": x, "x_mask": x_mask, "zin": zinput.astype(np.float32), "g": g}
|
| 93 |
+
)[0] * (sdp_ratio) + self.dp.run(None, {"x": x, "x_mask": x_mask, "g": g})[
|
| 94 |
+
0
|
| 95 |
+
] * (
|
| 96 |
+
1 - sdp_ratio
|
| 97 |
+
)
|
| 98 |
+
w = np.exp(logw) * x_mask * length_scale
|
| 99 |
+
w_ceil = np.ceil(w)
|
| 100 |
+
y_lengths = np.clip(np.sum(w_ceil, (1, 2)), a_min=1.0, a_max=100000).astype(
|
| 101 |
+
np.int64
|
| 102 |
+
)
|
| 103 |
+
y_mask = np.expand_dims(sequence_mask(y_lengths, None), 1)
|
| 104 |
+
attn_mask = np.expand_dims(x_mask, 2) * np.expand_dims(y_mask, -1)
|
| 105 |
+
attn = generate_path(w_ceil, attn_mask)
|
| 106 |
+
m_p = np.matmul(attn.squeeze(1), m_p.transpose(0, 2, 1)).transpose(
|
| 107 |
+
0, 2, 1
|
| 108 |
+
) # [b, t', t], [b, t, d] -> [b, d, t']
|
| 109 |
+
logs_p = np.matmul(attn.squeeze(1), logs_p.transpose(0, 2, 1)).transpose(
|
| 110 |
+
0, 2, 1
|
| 111 |
+
) # [b, t', t], [b, t, d] -> [b, d, t']
|
| 112 |
+
|
| 113 |
+
z_p = (
|
| 114 |
+
m_p
|
| 115 |
+
+ np.random.randn(m_p.shape[0], m_p.shape[1], m_p.shape[2])
|
| 116 |
+
* np.exp(logs_p)
|
| 117 |
+
* seq_noise_scale
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
z = self.flow.run(
|
| 121 |
+
None,
|
| 122 |
+
{
|
| 123 |
+
"z_p": z_p.astype(np.float32),
|
| 124 |
+
"y_mask": y_mask.astype(np.float32),
|
| 125 |
+
"g": g,
|
| 126 |
+
},
|
| 127 |
+
)[0]
|
| 128 |
+
|
| 129 |
+
return self.dec.run(None, {"z_in": z.astype(np.float32), "g": g})[0]
|
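A hedged usage sketch for the V210 session above (the file paths are placeholders, not actual filenames from this upload): compared with the V200 wrapper it additionally takes a vqidx input before sid, which is forwarded to the encoder.

import numpy as np
from onnx_modules.V210_OnnxInference import OnnxInferenceSession

session = OnnxInferenceSession(
    {
        "enc": "onnx/MyModel/MyModel_enc_p.onnx",   # placeholder export paths
        "emb_g": "onnx/MyModel/MyModel_emb.onnx",
        "dp": "onnx/MyModel/MyModel_dp.onnx",
        "sdp": "onnx/MyModel/MyModel_sdp.onnx",
        "flow": "onnx/MyModel/MyModel_flow.onnx",
        "dec": "onnx/MyModel/MyModel_dec.onnx",
    },
    Providers=["CPUExecutionProvider"],
)

x = np.array([0, 97, 0, 8, 0, 23, 0])
audio = session(
    x,
    np.zeros_like(x),                     # tone
    np.zeros_like(x),                     # language
    np.random.randn(x.shape[0], 1024),    # bert_zh
    np.random.randn(x.shape[0], 1024),    # bert_jp
    np.random.randn(x.shape[0], 1024),    # bert_en
    np.array([0]),                        # vqidx (new in V210)
    np.array([0]),                        # sid
)
print(audio.shape)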
onnx_modules/V220/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
from .text.symbols import symbols
from .models_onnx import SynthesizerTrn

__all__ = ["symbols", "SynthesizerTrn"]
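Each versioned package exposes the same pair of names, so a caller can pick a version at runtime. A small sketch (assuming the packages import cleanly in the current environment):

import importlib

version = "V220"  # or "V200", "V210", ...
pkg = importlib.import_module(f"onnx_modules.{version}")
print(len(pkg.symbols), pkg.SynthesizerTrn.__name__)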
onnx_modules/V220/attentions_onnx.py
ADDED
|
@@ -0,0 +1,378 @@
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
|
| 6 |
+
import commons
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LayerNorm(nn.Module):
|
| 13 |
+
def __init__(self, channels, eps=1e-5):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.channels = channels
|
| 16 |
+
self.eps = eps
|
| 17 |
+
|
| 18 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
| 19 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
| 20 |
+
|
| 21 |
+
def forward(self, x):
|
| 22 |
+
x = x.transpose(1, -1)
|
| 23 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
| 24 |
+
return x.transpose(1, -1)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@torch.jit.script
|
| 28 |
+
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
| 29 |
+
n_channels_int = n_channels[0]
|
| 30 |
+
in_act = input_a + input_b
|
| 31 |
+
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
| 32 |
+
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
| 33 |
+
acts = t_act * s_act
|
| 34 |
+
return acts
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Encoder(nn.Module):
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
hidden_channels,
|
| 41 |
+
filter_channels,
|
| 42 |
+
n_heads,
|
| 43 |
+
n_layers,
|
| 44 |
+
kernel_size=1,
|
| 45 |
+
p_dropout=0.0,
|
| 46 |
+
window_size=4,
|
| 47 |
+
isflow=True,
|
| 48 |
+
**kwargs
|
| 49 |
+
):
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.hidden_channels = hidden_channels
|
| 52 |
+
self.filter_channels = filter_channels
|
| 53 |
+
self.n_heads = n_heads
|
| 54 |
+
self.n_layers = n_layers
|
| 55 |
+
self.kernel_size = kernel_size
|
| 56 |
+
self.p_dropout = p_dropout
|
| 57 |
+
self.window_size = window_size
|
| 58 |
+
# if isflow:
|
| 59 |
+
# cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
|
| 60 |
+
# self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
|
| 61 |
+
# self.cond_layer = weight_norm(cond_layer, name='weight')
|
| 62 |
+
# self.gin_channels = 256
|
| 63 |
+
self.cond_layer_idx = self.n_layers
|
| 64 |
+
if "gin_channels" in kwargs:
|
| 65 |
+
self.gin_channels = kwargs["gin_channels"]
|
| 66 |
+
if self.gin_channels != 0:
|
| 67 |
+
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
|
| 68 |
+
# vits2 says 3rd block, so idx is 2 by default
|
| 69 |
+
self.cond_layer_idx = (
|
| 70 |
+
kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
|
| 71 |
+
)
|
| 72 |
+
logging.debug("gin_channels: %s, cond_layer_idx: %s", self.gin_channels, self.cond_layer_idx)
|
| 73 |
+
assert (
|
| 74 |
+
self.cond_layer_idx < self.n_layers
|
| 75 |
+
), "cond_layer_idx should be less than n_layers"
|
| 76 |
+
self.drop = nn.Dropout(p_dropout)
|
| 77 |
+
self.attn_layers = nn.ModuleList()
|
| 78 |
+
self.norm_layers_1 = nn.ModuleList()
|
| 79 |
+
self.ffn_layers = nn.ModuleList()
|
| 80 |
+
self.norm_layers_2 = nn.ModuleList()
|
| 81 |
+
for i in range(self.n_layers):
|
| 82 |
+
self.attn_layers.append(
|
| 83 |
+
MultiHeadAttention(
|
| 84 |
+
hidden_channels,
|
| 85 |
+
hidden_channels,
|
| 86 |
+
n_heads,
|
| 87 |
+
p_dropout=p_dropout,
|
| 88 |
+
window_size=window_size,
|
| 89 |
+
)
|
| 90 |
+
)
|
| 91 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
| 92 |
+
self.ffn_layers.append(
|
| 93 |
+
FFN(
|
| 94 |
+
hidden_channels,
|
| 95 |
+
hidden_channels,
|
| 96 |
+
filter_channels,
|
| 97 |
+
kernel_size,
|
| 98 |
+
p_dropout=p_dropout,
|
| 99 |
+
)
|
| 100 |
+
)
|
| 101 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
| 102 |
+
|
| 103 |
+
def forward(self, x, x_mask, g=None):
|
| 104 |
+
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
| 105 |
+
x = x * x_mask
|
| 106 |
+
for i in range(self.n_layers):
|
| 107 |
+
if i == self.cond_layer_idx and g is not None:
|
| 108 |
+
g = self.spk_emb_linear(g.transpose(1, 2))
|
| 109 |
+
g = g.transpose(1, 2)
|
| 110 |
+
x = x + g
|
| 111 |
+
x = x * x_mask
|
| 112 |
+
y = self.attn_layers[i](x, x, attn_mask)
|
| 113 |
+
y = self.drop(y)
|
| 114 |
+
x = self.norm_layers_1[i](x + y)
|
| 115 |
+
|
| 116 |
+
y = self.ffn_layers[i](x, x_mask)
|
| 117 |
+
y = self.drop(y)
|
| 118 |
+
x = self.norm_layers_2[i](x + y)
|
| 119 |
+
x = x * x_mask
|
| 120 |
+
return x
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class MultiHeadAttention(nn.Module):
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
channels,
|
| 127 |
+
out_channels,
|
| 128 |
+
n_heads,
|
| 129 |
+
p_dropout=0.0,
|
| 130 |
+
window_size=None,
|
| 131 |
+
heads_share=True,
|
| 132 |
+
block_length=None,
|
| 133 |
+
proximal_bias=False,
|
| 134 |
+
proximal_init=False,
|
| 135 |
+
):
|
| 136 |
+
super().__init__()
|
| 137 |
+
assert channels % n_heads == 0
|
| 138 |
+
|
| 139 |
+
self.channels = channels
|
| 140 |
+
self.out_channels = out_channels
|
| 141 |
+
self.n_heads = n_heads
|
| 142 |
+
self.p_dropout = p_dropout
|
| 143 |
+
self.window_size = window_size
|
| 144 |
+
self.heads_share = heads_share
|
| 145 |
+
self.block_length = block_length
|
| 146 |
+
self.proximal_bias = proximal_bias
|
| 147 |
+
self.proximal_init = proximal_init
|
| 148 |
+
self.attn = None
|
| 149 |
+
|
| 150 |
+
self.k_channels = channels // n_heads
|
| 151 |
+
self.conv_q = nn.Conv1d(channels, channels, 1)
|
| 152 |
+
self.conv_k = nn.Conv1d(channels, channels, 1)
|
| 153 |
+
self.conv_v = nn.Conv1d(channels, channels, 1)
|
| 154 |
+
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
| 155 |
+
self.drop = nn.Dropout(p_dropout)
|
| 156 |
+
|
| 157 |
+
if window_size is not None:
|
| 158 |
+
n_heads_rel = 1 if heads_share else n_heads
|
| 159 |
+
rel_stddev = self.k_channels**-0.5
|
| 160 |
+
self.emb_rel_k = nn.Parameter(
|
| 161 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
| 162 |
+
* rel_stddev
|
| 163 |
+
)
|
| 164 |
+
self.emb_rel_v = nn.Parameter(
|
| 165 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
| 166 |
+
* rel_stddev
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
nn.init.xavier_uniform_(self.conv_q.weight)
|
| 170 |
+
nn.init.xavier_uniform_(self.conv_k.weight)
|
| 171 |
+
nn.init.xavier_uniform_(self.conv_v.weight)
|
| 172 |
+
if proximal_init:
|
| 173 |
+
with torch.no_grad():
|
| 174 |
+
self.conv_k.weight.copy_(self.conv_q.weight)
|
| 175 |
+
self.conv_k.bias.copy_(self.conv_q.bias)
|
| 176 |
+
|
| 177 |
+
def forward(self, x, c, attn_mask=None):
|
| 178 |
+
q = self.conv_q(x)
|
| 179 |
+
k = self.conv_k(c)
|
| 180 |
+
v = self.conv_v(c)
|
| 181 |
+
|
| 182 |
+
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
| 183 |
+
|
| 184 |
+
x = self.conv_o(x)
|
| 185 |
+
return x
|
| 186 |
+
|
| 187 |
+
def attention(self, query, key, value, mask=None):
|
| 188 |
+
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
| 189 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
| 190 |
+
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
| 191 |
+
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
| 192 |
+
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
| 193 |
+
|
| 194 |
+
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
| 195 |
+
if self.window_size is not None:
|
| 196 |
+
assert (
|
| 197 |
+
t_s == t_t
|
| 198 |
+
), "Relative attention is only available for self-attention."
|
| 199 |
+
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
| 200 |
+
rel_logits = self._matmul_with_relative_keys(
|
| 201 |
+
query / math.sqrt(self.k_channels), key_relative_embeddings
|
| 202 |
+
)
|
| 203 |
+
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
| 204 |
+
scores = scores + scores_local
|
| 205 |
+
if self.proximal_bias:
|
| 206 |
+
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
| 207 |
+
scores = scores + self._attention_bias_proximal(t_s).to(
|
| 208 |
+
device=scores.device, dtype=scores.dtype
|
| 209 |
+
)
|
| 210 |
+
if mask is not None:
|
| 211 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
| 212 |
+
if self.block_length is not None:
|
| 213 |
+
assert (
|
| 214 |
+
t_s == t_t
|
| 215 |
+
), "Local attention is only available for self-attention."
|
| 216 |
+
block_mask = (
|
| 217 |
+
torch.ones_like(scores)
|
| 218 |
+
.triu(-self.block_length)
|
| 219 |
+
.tril(self.block_length)
|
| 220 |
+
)
|
| 221 |
+
scores = scores.masked_fill(block_mask == 0, -1e4)
|
| 222 |
+
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
| 223 |
+
p_attn = self.drop(p_attn)
|
| 224 |
+
output = torch.matmul(p_attn, value)
|
| 225 |
+
if self.window_size is not None:
|
| 226 |
+
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
| 227 |
+
value_relative_embeddings = self._get_relative_embeddings(
|
| 228 |
+
self.emb_rel_v, t_s
|
| 229 |
+
)
|
| 230 |
+
output = output + self._matmul_with_relative_values(
|
| 231 |
+
relative_weights, value_relative_embeddings
|
| 232 |
+
)
|
| 233 |
+
output = (
|
| 234 |
+
output.transpose(2, 3).contiguous().view(b, d, t_t)
|
| 235 |
+
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
| 236 |
+
return output, p_attn
|
| 237 |
+
|
| 238 |
+
def _matmul_with_relative_values(self, x, y):
|
| 239 |
+
"""
|
| 240 |
+
x: [b, h, l, m]
|
| 241 |
+
y: [h or 1, m, d]
|
| 242 |
+
ret: [b, h, l, d]
|
| 243 |
+
"""
|
| 244 |
+
ret = torch.matmul(x, y.unsqueeze(0))
|
| 245 |
+
return ret
|
| 246 |
+
|
| 247 |
+
def _matmul_with_relative_keys(self, x, y):
|
| 248 |
+
"""
|
| 249 |
+
x: [b, h, l, d]
|
| 250 |
+
y: [h or 1, m, d]
|
| 251 |
+
ret: [b, h, l, m]
|
| 252 |
+
"""
|
| 253 |
+
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
| 254 |
+
return ret
|
| 255 |
+
|
| 256 |
+
def _get_relative_embeddings(self, relative_embeddings, length):
|
| 257 |
+
max_relative_position = 2 * self.window_size + 1
|
| 258 |
+
# Pad first before slice to avoid using cond ops.
|
| 259 |
+
pad_length = max(length - (self.window_size + 1), 0)
|
| 260 |
+
slice_start_position = max((self.window_size + 1) - length, 0)
|
| 261 |
+
slice_end_position = slice_start_position + 2 * length - 1
|
| 262 |
+
if pad_length > 0:
|
| 263 |
+
padded_relative_embeddings = F.pad(
|
| 264 |
+
relative_embeddings,
|
| 265 |
+
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
| 266 |
+
)
|
| 267 |
+
else:
|
| 268 |
+
padded_relative_embeddings = relative_embeddings
|
| 269 |
+
used_relative_embeddings = padded_relative_embeddings[
|
| 270 |
+
:, slice_start_position:slice_end_position
|
| 271 |
+
]
|
| 272 |
+
return used_relative_embeddings
|
| 273 |
+
|
| 274 |
+
def _relative_position_to_absolute_position(self, x):
|
| 275 |
+
"""
|
| 276 |
+
x: [b, h, l, 2*l-1]
|
| 277 |
+
ret: [b, h, l, l]
|
| 278 |
+
"""
|
| 279 |
+
batch, heads, length, _ = x.size()
|
| 280 |
+
# Concat columns of pad to shift from relative to absolute indexing.
|
| 281 |
+
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
|
| 282 |
+
|
| 283 |
+
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
| 284 |
+
x_flat = x.view([batch, heads, length * 2 * length])
|
| 285 |
+
x_flat = F.pad(
|
| 286 |
+
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# Reshape and slice out the padded elements.
|
| 290 |
+
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
|
| 291 |
+
:, :, :length, length - 1 :
|
| 292 |
+
]
|
| 293 |
+
return x_final
|
| 294 |
+
|
| 295 |
+
def _absolute_position_to_relative_position(self, x):
|
| 296 |
+
"""
|
| 297 |
+
x: [b, h, l, l]
|
| 298 |
+
ret: [b, h, l, 2*l-1]
|
| 299 |
+
"""
|
| 300 |
+
batch, heads, length, _ = x.size()
|
| 301 |
+
# padd along column
|
| 302 |
+
x = F.pad(
|
| 303 |
+
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
| 304 |
+
)
|
| 305 |
+
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
| 306 |
+
# add 0's in the beginning that will skew the elements after reshape
|
| 307 |
+
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
| 308 |
+
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
| 309 |
+
return x_final
|
| 310 |
+
|
| 311 |
+
def _attention_bias_proximal(self, length):
|
| 312 |
+
"""Bias for self-attention to encourage attention to close positions.
|
| 313 |
+
Args:
|
| 314 |
+
length: an integer scalar.
|
| 315 |
+
Returns:
|
| 316 |
+
a Tensor with shape [1, 1, length, length]
|
| 317 |
+
"""
|
| 318 |
+
r = torch.arange(length, dtype=torch.float32)
|
| 319 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
| 320 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class FFN(nn.Module):
|
| 324 |
+
def __init__(
|
| 325 |
+
self,
|
| 326 |
+
in_channels,
|
| 327 |
+
out_channels,
|
| 328 |
+
filter_channels,
|
| 329 |
+
kernel_size,
|
| 330 |
+
p_dropout=0.0,
|
| 331 |
+
activation=None,
|
| 332 |
+
causal=False,
|
| 333 |
+
):
|
| 334 |
+
super().__init__()
|
| 335 |
+
self.in_channels = in_channels
|
| 336 |
+
self.out_channels = out_channels
|
| 337 |
+
self.filter_channels = filter_channels
|
| 338 |
+
self.kernel_size = kernel_size
|
| 339 |
+
self.p_dropout = p_dropout
|
| 340 |
+
self.activation = activation
|
| 341 |
+
self.causal = causal
|
| 342 |
+
|
| 343 |
+
if causal:
|
| 344 |
+
self.padding = self._causal_padding
|
| 345 |
+
else:
|
| 346 |
+
self.padding = self._same_padding
|
| 347 |
+
|
| 348 |
+
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
| 349 |
+
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
| 350 |
+
self.drop = nn.Dropout(p_dropout)
|
| 351 |
+
|
| 352 |
+
def forward(self, x, x_mask):
|
| 353 |
+
x = self.conv_1(self.padding(x * x_mask))
|
| 354 |
+
if self.activation == "gelu":
|
| 355 |
+
x = x * torch.sigmoid(1.702 * x)
|
| 356 |
+
else:
|
| 357 |
+
x = torch.relu(x)
|
| 358 |
+
x = self.drop(x)
|
| 359 |
+
x = self.conv_2(self.padding(x * x_mask))
|
| 360 |
+
return x * x_mask
|
| 361 |
+
|
| 362 |
+
def _causal_padding(self, x):
|
| 363 |
+
if self.kernel_size == 1:
|
| 364 |
+
return x
|
| 365 |
+
pad_l = self.kernel_size - 1
|
| 366 |
+
pad_r = 0
|
| 367 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
| 368 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
| 369 |
+
return x
|
| 370 |
+
|
| 371 |
+
def _same_padding(self, x):
|
| 372 |
+
if self.kernel_size == 1:
|
| 373 |
+
return x
|
| 374 |
+
pad_l = (self.kernel_size - 1) // 2
|
| 375 |
+
pad_r = self.kernel_size // 2
|
| 376 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
| 377 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
| 378 |
+
return x
|
onnx_modules/V220/models_onnx.py
ADDED
|
@@ -0,0 +1,1076 @@
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
|
| 6 |
+
import commons
|
| 7 |
+
import modules
|
| 8 |
+
from . import attentions_onnx
|
| 9 |
+
from vector_quantize_pytorch import VectorQuantize
|
| 10 |
+
|
| 11 |
+
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
| 12 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
| 13 |
+
from commons import init_weights, get_padding
|
| 14 |
+
from .text import symbols, num_tones, num_languages
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DurationDiscriminator(nn.Module): # vits2
|
| 18 |
+
def __init__(
|
| 19 |
+
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
|
| 23 |
+
self.in_channels = in_channels
|
| 24 |
+
self.filter_channels = filter_channels
|
| 25 |
+
self.kernel_size = kernel_size
|
| 26 |
+
self.p_dropout = p_dropout
|
| 27 |
+
self.gin_channels = gin_channels
|
| 28 |
+
|
| 29 |
+
self.drop = nn.Dropout(p_dropout)
|
| 30 |
+
self.conv_1 = nn.Conv1d(
|
| 31 |
+
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 32 |
+
)
|
| 33 |
+
self.norm_1 = modules.LayerNorm(filter_channels)
|
| 34 |
+
self.conv_2 = nn.Conv1d(
|
| 35 |
+
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 36 |
+
)
|
| 37 |
+
self.norm_2 = modules.LayerNorm(filter_channels)
|
| 38 |
+
self.dur_proj = nn.Conv1d(1, filter_channels, 1)
|
| 39 |
+
|
| 40 |
+
self.pre_out_conv_1 = nn.Conv1d(
|
| 41 |
+
2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 42 |
+
)
|
| 43 |
+
self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
|
| 44 |
+
self.pre_out_conv_2 = nn.Conv1d(
|
| 45 |
+
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 46 |
+
)
|
| 47 |
+
self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
|
| 48 |
+
|
| 49 |
+
if gin_channels != 0:
|
| 50 |
+
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
| 51 |
+
|
| 52 |
+
self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
|
| 53 |
+
|
| 54 |
+
def forward_probability(self, x, x_mask, dur, g=None):
|
| 55 |
+
dur = self.dur_proj(dur)
|
| 56 |
+
x = torch.cat([x, dur], dim=1)
|
| 57 |
+
x = self.pre_out_conv_1(x * x_mask)
|
| 58 |
+
x = torch.relu(x)
|
| 59 |
+
x = self.pre_out_norm_1(x)
|
| 60 |
+
x = self.drop(x)
|
| 61 |
+
x = self.pre_out_conv_2(x * x_mask)
|
| 62 |
+
x = torch.relu(x)
|
| 63 |
+
x = self.pre_out_norm_2(x)
|
| 64 |
+
x = self.drop(x)
|
| 65 |
+
x = x * x_mask
|
| 66 |
+
x = x.transpose(1, 2)
|
| 67 |
+
output_prob = self.output_layer(x)
|
| 68 |
+
return output_prob
|
| 69 |
+
|
| 70 |
+
def forward(self, x, x_mask, dur_r, dur_hat, g=None):
|
| 71 |
+
x = torch.detach(x)
|
| 72 |
+
if g is not None:
|
| 73 |
+
g = torch.detach(g)
|
| 74 |
+
x = x + self.cond(g)
|
| 75 |
+
x = self.conv_1(x * x_mask)
|
| 76 |
+
x = torch.relu(x)
|
| 77 |
+
x = self.norm_1(x)
|
| 78 |
+
x = self.drop(x)
|
| 79 |
+
x = self.conv_2(x * x_mask)
|
| 80 |
+
x = torch.relu(x)
|
| 81 |
+
x = self.norm_2(x)
|
| 82 |
+
x = self.drop(x)
|
| 83 |
+
|
| 84 |
+
output_probs = []
|
| 85 |
+
for dur in [dur_r, dur_hat]:
|
| 86 |
+
output_prob = self.forward_probability(x, x_mask, dur, g)
|
| 87 |
+
output_probs.append(output_prob)
|
| 88 |
+
|
| 89 |
+
return output_probs
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class TransformerCouplingBlock(nn.Module):
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
channels,
|
| 96 |
+
hidden_channels,
|
| 97 |
+
filter_channels,
|
| 98 |
+
n_heads,
|
| 99 |
+
n_layers,
|
| 100 |
+
kernel_size,
|
| 101 |
+
p_dropout,
|
| 102 |
+
n_flows=4,
|
| 103 |
+
gin_channels=0,
|
| 104 |
+
share_parameter=False,
|
| 105 |
+
):
|
| 106 |
+
super().__init__()
|
| 107 |
+
self.channels = channels
|
| 108 |
+
self.hidden_channels = hidden_channels
|
| 109 |
+
self.kernel_size = kernel_size
|
| 110 |
+
self.n_layers = n_layers
|
| 111 |
+
self.n_flows = n_flows
|
| 112 |
+
self.gin_channels = gin_channels
|
| 113 |
+
|
| 114 |
+
self.flows = nn.ModuleList()
|
| 115 |
+
|
| 116 |
+
self.wn = (
|
| 117 |
+
attentions_onnx.FFT(
|
| 118 |
+
hidden_channels,
|
| 119 |
+
filter_channels,
|
| 120 |
+
n_heads,
|
| 121 |
+
n_layers,
|
| 122 |
+
kernel_size,
|
| 123 |
+
p_dropout,
|
| 124 |
+
isflow=True,
|
| 125 |
+
gin_channels=self.gin_channels,
|
| 126 |
+
)
|
| 127 |
+
if share_parameter
|
| 128 |
+
else None
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
for i in range(n_flows):
|
| 132 |
+
self.flows.append(
|
| 133 |
+
modules.TransformerCouplingLayer(
|
| 134 |
+
channels,
|
| 135 |
+
hidden_channels,
|
| 136 |
+
kernel_size,
|
| 137 |
+
n_layers,
|
| 138 |
+
n_heads,
|
| 139 |
+
p_dropout,
|
| 140 |
+
filter_channels,
|
| 141 |
+
mean_only=True,
|
| 142 |
+
wn_sharing_parameter=self.wn,
|
| 143 |
+
gin_channels=self.gin_channels,
|
| 144 |
+
)
|
| 145 |
+
)
|
| 146 |
+
self.flows.append(modules.Flip())
|
| 147 |
+
|
| 148 |
+
def forward(self, x, x_mask, g=None, reverse=True):
|
| 149 |
+
if not reverse:
|
| 150 |
+
for flow in self.flows:
|
| 151 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
| 152 |
+
else:
|
| 153 |
+
for flow in reversed(self.flows):
|
| 154 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
| 155 |
+
return x
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class StochasticDurationPredictor(nn.Module):
|
| 159 |
+
def __init__(
|
| 160 |
+
self,
|
| 161 |
+
in_channels,
|
| 162 |
+
filter_channels,
|
| 163 |
+
kernel_size,
|
| 164 |
+
p_dropout,
|
| 165 |
+
n_flows=4,
|
| 166 |
+
gin_channels=0,
|
| 167 |
+
):
|
| 168 |
+
super().__init__()
|
| 169 |
+
filter_channels = in_channels  # TODO: this override should be removed in a future version
|
| 170 |
+
self.in_channels = in_channels
|
| 171 |
+
self.filter_channels = filter_channels
|
| 172 |
+
self.kernel_size = kernel_size
|
| 173 |
+
self.p_dropout = p_dropout
|
| 174 |
+
self.n_flows = n_flows
|
| 175 |
+
self.gin_channels = gin_channels
|
| 176 |
+
|
| 177 |
+
self.log_flow = modules.Log()
|
| 178 |
+
self.flows = nn.ModuleList()
|
| 179 |
+
self.flows.append(modules.ElementwiseAffine(2))
|
| 180 |
+
for i in range(n_flows):
|
| 181 |
+
self.flows.append(
|
| 182 |
+
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
| 183 |
+
)
|
| 184 |
+
self.flows.append(modules.Flip())
|
| 185 |
+
|
| 186 |
+
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
| 187 |
+
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
| 188 |
+
self.post_convs = modules.DDSConv(
|
| 189 |
+
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
| 190 |
+
)
|
| 191 |
+
self.post_flows = nn.ModuleList()
|
| 192 |
+
self.post_flows.append(modules.ElementwiseAffine(2))
|
| 193 |
+
for i in range(4):
|
| 194 |
+
self.post_flows.append(
|
| 195 |
+
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
| 196 |
+
)
|
| 197 |
+
self.post_flows.append(modules.Flip())
|
| 198 |
+
|
| 199 |
+
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
| 200 |
+
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
| 201 |
+
self.convs = modules.DDSConv(
|
| 202 |
+
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
| 203 |
+
)
|
| 204 |
+
if gin_channels != 0:
|
| 205 |
+
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
| 206 |
+
|
| 207 |
+
def forward(self, x, x_mask, z, g=None):
|
| 208 |
+
x = torch.detach(x)
|
| 209 |
+
x = self.pre(x)
|
| 210 |
+
if g is not None:
|
| 211 |
+
g = torch.detach(g)
|
| 212 |
+
x = x + self.cond(g)
|
| 213 |
+
x = self.convs(x, x_mask)
|
| 214 |
+
x = self.proj(x) * x_mask
|
| 215 |
+
|
| 216 |
+
flows = list(reversed(self.flows))
|
| 217 |
+
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
| 218 |
+
for flow in flows:
|
| 219 |
+
z = flow(z, x_mask, g=x, reverse=True)
|
| 220 |
+
z0, z1 = torch.split(z, [1, 1], 1)
|
| 221 |
+
logw = z0
|
| 222 |
+
return logw
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class DurationPredictor(nn.Module):
|
| 226 |
+
def __init__(
|
| 227 |
+
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
|
| 228 |
+
):
|
| 229 |
+
super().__init__()
|
| 230 |
+
|
| 231 |
+
self.in_channels = in_channels
|
| 232 |
+
self.filter_channels = filter_channels
|
| 233 |
+
self.kernel_size = kernel_size
|
| 234 |
+
self.p_dropout = p_dropout
|
| 235 |
+
self.gin_channels = gin_channels
|
| 236 |
+
|
| 237 |
+
self.drop = nn.Dropout(p_dropout)
|
| 238 |
+
self.conv_1 = nn.Conv1d(
|
| 239 |
+
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 240 |
+
)
|
| 241 |
+
self.norm_1 = modules.LayerNorm(filter_channels)
|
| 242 |
+
self.conv_2 = nn.Conv1d(
|
| 243 |
+
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 244 |
+
)
|
| 245 |
+
self.norm_2 = modules.LayerNorm(filter_channels)
|
| 246 |
+
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
| 247 |
+
|
| 248 |
+
if gin_channels != 0:
|
| 249 |
+
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
| 250 |
+
|
| 251 |
+
def forward(self, x, x_mask, g=None):
|
| 252 |
+
x = torch.detach(x)
|
| 253 |
+
if g is not None:
|
| 254 |
+
g = torch.detach(g)
|
| 255 |
+
x = x + self.cond(g)
|
| 256 |
+
x = self.conv_1(x * x_mask)
|
| 257 |
+
x = torch.relu(x)
|
| 258 |
+
x = self.norm_1(x)
|
| 259 |
+
x = self.drop(x)
|
| 260 |
+
x = self.conv_2(x * x_mask)
|
| 261 |
+
x = torch.relu(x)
|
| 262 |
+
x = self.norm_2(x)
|
| 263 |
+
x = self.drop(x)
|
| 264 |
+
x = self.proj(x * x_mask)
|
| 265 |
+
return x * x_mask
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class Bottleneck(nn.Sequential):
|
| 269 |
+
def __init__(self, in_dim, hidden_dim):
|
| 270 |
+
c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
|
| 271 |
+
c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
|
| 272 |
+
super().__init__(*[c_fc1, c_fc2])
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class Block(nn.Module):
|
| 276 |
+
def __init__(self, in_dim, hidden_dim) -> None:
|
| 277 |
+
super().__init__()
|
| 278 |
+
self.norm = nn.LayerNorm(in_dim)
|
| 279 |
+
self.mlp = MLP(in_dim, hidden_dim)
|
| 280 |
+
|
| 281 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 282 |
+
x = x + self.mlp(self.norm(x))
|
| 283 |
+
return x
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class MLP(nn.Module):
|
| 287 |
+
def __init__(self, in_dim, hidden_dim):
|
| 288 |
+
super().__init__()
|
| 289 |
+
self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
|
| 290 |
+
self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
|
| 291 |
+
self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)
|
| 292 |
+
|
| 293 |
+
def forward(self, x: torch.Tensor):
|
| 294 |
+
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
|
| 295 |
+
x = self.c_proj(x)
|
| 296 |
+
return x
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class TextEncoder(nn.Module):
|
| 300 |
+
def __init__(
|
| 301 |
+
self,
|
| 302 |
+
n_vocab,
|
| 303 |
+
out_channels,
|
| 304 |
+
hidden_channels,
|
| 305 |
+
filter_channels,
|
| 306 |
+
n_heads,
|
| 307 |
+
n_layers,
|
| 308 |
+
kernel_size,
|
| 309 |
+
p_dropout,
|
| 310 |
+
n_speakers,
|
| 311 |
+
gin_channels=0,
|
| 312 |
+
):
|
| 313 |
+
super().__init__()
|
| 314 |
+
self.n_vocab = n_vocab
|
| 315 |
+
self.out_channels = out_channels
|
| 316 |
+
self.hidden_channels = hidden_channels
|
| 317 |
+
self.filter_channels = filter_channels
|
| 318 |
+
self.n_heads = n_heads
|
| 319 |
+
self.n_layers = n_layers
|
| 320 |
+
self.kernel_size = kernel_size
|
| 321 |
+
self.p_dropout = p_dropout
|
| 322 |
+
self.gin_channels = gin_channels
|
| 323 |
+
self.emb = nn.Embedding(len(symbols), hidden_channels)
|
| 324 |
+
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
| 325 |
+
self.tone_emb = nn.Embedding(num_tones, hidden_channels)
|
| 326 |
+
nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
|
| 327 |
+
self.language_emb = nn.Embedding(num_languages, hidden_channels)
|
| 328 |
+
nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
|
| 329 |
+
self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
| 330 |
+
self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
| 331 |
+
self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
| 332 |
+
# self.emo_proj = nn.Linear(1024, 1024)
|
| 333 |
+
# self.emo_quantizer = nn.ModuleList()
|
| 334 |
+
# for i in range(0, n_speakers):
|
| 335 |
+
# self.emo_quantizer.append(
|
| 336 |
+
# VectorQuantize(
|
| 337 |
+
# dim=1024,
|
| 338 |
+
# codebook_size=10,
|
| 339 |
+
# decay=0.8,
|
| 340 |
+
# commitment_weight=1.0,
|
| 341 |
+
# learnable_codebook=True,
|
| 342 |
+
# ema_update=False,
|
| 343 |
+
# )
|
| 344 |
+
# )
|
| 345 |
+
# self.emo_q_proj = nn.Linear(1024, hidden_channels)
|
| 346 |
+
self.n_speakers = n_speakers
|
| 347 |
+
self.in_feature_net = nn.Sequential(
|
| 348 |
+
# input is assumed to be an already-normalized embedding
|
| 349 |
+
nn.Linear(512, 1028, bias=False),
|
| 350 |
+
nn.GELU(),
|
| 351 |
+
nn.LayerNorm(1028),
|
| 352 |
+
*[Block(1028, 512) for _ in range(1)],
|
| 353 |
+
nn.Linear(1028, 512, bias=False),
|
| 354 |
+
# normalize before passing to VQ?
|
| 355 |
+
# nn.GELU(),
|
| 356 |
+
# nn.LayerNorm(512),
|
| 357 |
+
)
|
| 358 |
+
self.emo_vq = VectorQuantize(
|
| 359 |
+
dim=512,
|
| 360 |
+
codebook_size=64,
|
| 361 |
+
codebook_dim=32,
|
| 362 |
+
commitment_weight=0.1,
|
| 363 |
+
decay=0.85,
|
| 364 |
+
heads=32,
|
| 365 |
+
kmeans_iters=20,
|
| 366 |
+
separate_codebook_per_head=True,
|
| 367 |
+
stochastic_sample_codes=True,
|
| 368 |
+
threshold_ema_dead_code=2,
|
| 369 |
+
)
|
| 370 |
+
self.out_feature_net = nn.Linear(512, hidden_channels)
|
| 371 |
+
|
| 372 |
+
self.encoder = attentions_onnx.Encoder(
|
| 373 |
+
hidden_channels,
|
| 374 |
+
filter_channels,
|
| 375 |
+
n_heads,
|
| 376 |
+
n_layers,
|
| 377 |
+
kernel_size,
|
| 378 |
+
p_dropout,
|
| 379 |
+
gin_channels=self.gin_channels,
|
| 380 |
+
)
|
| 381 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
| 382 |
+
|
| 383 |
+
def forward(
|
| 384 |
+
self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g=None
|
| 385 |
+
):
|
| 386 |
+
x_mask = torch.ones_like(x).unsqueeze(0)
|
| 387 |
+
bert_emb = self.bert_proj(bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
|
| 388 |
+
ja_bert_emb = self.ja_bert_proj(ja_bert.transpose(0, 1).unsqueeze(0)).transpose(
|
| 389 |
+
1, 2
|
| 390 |
+
)
|
| 391 |
+
en_bert_emb = self.en_bert_proj(en_bert.transpose(0, 1).unsqueeze(0)).transpose(
|
| 392 |
+
1, 2
|
| 393 |
+
)
|
| 394 |
+
emo_emb = self.in_feature_net(emo.transpose(0, 1))
|
| 395 |
+
emo_emb, _, _ = self.emo_vq(emo_emb.unsqueeze(1))
|
| 396 |
+
|
| 397 |
+
emo_emb = self.out_feature_net(emo_emb)
|
| 398 |
+
|
| 399 |
+
x = (
|
| 400 |
+
self.emb(x)
|
| 401 |
+
+ self.tone_emb(tone)
|
| 402 |
+
+ self.language_emb(language)
|
| 403 |
+
+ bert_emb
|
| 404 |
+
+ ja_bert_emb
|
| 405 |
+
+ en_bert_emb
|
| 406 |
+
+ emo_emb
|
| 407 |
+
) * math.sqrt(
|
| 408 |
+
self.hidden_channels
|
| 409 |
+
) # [b, t, h]
|
| 410 |
+
x = torch.transpose(x, 1, -1) # [b, h, t]
|
| 411 |
+
x_mask = x_mask.to(x.dtype)
|
| 412 |
+
|
| 413 |
+
x = self.encoder(x * x_mask, x_mask, g=g)
|
| 414 |
+
stats = self.proj(x) * x_mask
|
| 415 |
+
|
| 416 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
| 417 |
+
return x, m, logs, x_mask
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
class ResidualCouplingBlock(nn.Module):
|
| 421 |
+
def __init__(
|
| 422 |
+
self,
|
| 423 |
+
channels,
|
| 424 |
+
hidden_channels,
|
| 425 |
+
kernel_size,
|
| 426 |
+
dilation_rate,
|
| 427 |
+
n_layers,
|
| 428 |
+
n_flows=4,
|
| 429 |
+
gin_channels=0,
|
| 430 |
+
):
|
| 431 |
+
super().__init__()
|
| 432 |
+
self.channels = channels
|
| 433 |
+
self.hidden_channels = hidden_channels
|
| 434 |
+
self.kernel_size = kernel_size
|
| 435 |
+
self.dilation_rate = dilation_rate
|
| 436 |
+
self.n_layers = n_layers
|
| 437 |
+
self.n_flows = n_flows
|
| 438 |
+
self.gin_channels = gin_channels
|
| 439 |
+
|
| 440 |
+
self.flows = nn.ModuleList()
|
| 441 |
+
for i in range(n_flows):
|
| 442 |
+
self.flows.append(
|
| 443 |
+
modules.ResidualCouplingLayer(
|
| 444 |
+
channels,
|
| 445 |
+
hidden_channels,
|
| 446 |
+
kernel_size,
|
| 447 |
+
dilation_rate,
|
| 448 |
+
n_layers,
|
| 449 |
+
gin_channels=gin_channels,
|
| 450 |
+
mean_only=True,
|
| 451 |
+
)
|
| 452 |
+
)
|
| 453 |
+
self.flows.append(modules.Flip())
|
| 454 |
+
|
| 455 |
+
def forward(self, x, x_mask, g=None, reverse=True):
|
| 456 |
+
if not reverse:
|
| 457 |
+
for flow in self.flows:
|
| 458 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
| 459 |
+
else:
|
| 460 |
+
for flow in reversed(self.flows):
|
| 461 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
| 462 |
+
return x
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
class PosteriorEncoder(nn.Module):
|
| 466 |
+
def __init__(
|
| 467 |
+
self,
|
| 468 |
+
in_channels,
|
| 469 |
+
out_channels,
|
| 470 |
+
hidden_channels,
|
| 471 |
+
kernel_size,
|
| 472 |
+
dilation_rate,
|
| 473 |
+
n_layers,
|
| 474 |
+
gin_channels=0,
|
| 475 |
+
):
|
| 476 |
+
super().__init__()
|
| 477 |
+
self.in_channels = in_channels
|
| 478 |
+
self.out_channels = out_channels
|
| 479 |
+
self.hidden_channels = hidden_channels
|
| 480 |
+
self.kernel_size = kernel_size
|
| 481 |
+
self.dilation_rate = dilation_rate
|
| 482 |
+
self.n_layers = n_layers
|
| 483 |
+
self.gin_channels = gin_channels
|
| 484 |
+
|
| 485 |
+
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
| 486 |
+
self.enc = modules.WN(
|
| 487 |
+
hidden_channels,
|
| 488 |
+
kernel_size,
|
| 489 |
+
dilation_rate,
|
| 490 |
+
n_layers,
|
| 491 |
+
gin_channels=gin_channels,
|
| 492 |
+
)
|
| 493 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
| 494 |
+
|
| 495 |
+
def forward(self, x, x_lengths, g=None):
|
| 496 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
| 497 |
+
x.dtype
|
| 498 |
+
)
|
| 499 |
+
x = self.pre(x) * x_mask
|
| 500 |
+
x = self.enc(x, x_mask, g=g)
|
| 501 |
+
stats = self.proj(x) * x_mask
|
| 502 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
| 503 |
+
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
| 504 |
+
return z, m, logs, x_mask
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class Generator(torch.nn.Module):
|
| 508 |
+
def __init__(
|
| 509 |
+
self,
|
| 510 |
+
initial_channel,
|
| 511 |
+
resblock,
|
| 512 |
+
resblock_kernel_sizes,
|
| 513 |
+
resblock_dilation_sizes,
|
| 514 |
+
upsample_rates,
|
| 515 |
+
upsample_initial_channel,
|
| 516 |
+
upsample_kernel_sizes,
|
| 517 |
+
gin_channels=0,
|
| 518 |
+
):
|
| 519 |
+
super(Generator, self).__init__()
|
| 520 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
| 521 |
+
self.num_upsamples = len(upsample_rates)
|
| 522 |
+
self.conv_pre = Conv1d(
|
| 523 |
+
initial_channel, upsample_initial_channel, 7, 1, padding=3
|
| 524 |
+
)
|
| 525 |
+
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
| 526 |
+
|
| 527 |
+
self.ups = nn.ModuleList()
|
| 528 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
| 529 |
+
self.ups.append(
|
| 530 |
+
weight_norm(
|
| 531 |
+
ConvTranspose1d(
|
| 532 |
+
upsample_initial_channel // (2**i),
|
| 533 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
| 534 |
+
k,
|
| 535 |
+
u,
|
| 536 |
+
padding=(k - u) // 2,
|
| 537 |
+
)
|
| 538 |
+
)
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
self.resblocks = nn.ModuleList()
|
| 542 |
+
for i in range(len(self.ups)):
|
| 543 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
| 544 |
+
for j, (k, d) in enumerate(
|
| 545 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
| 546 |
+
):
|
| 547 |
+
self.resblocks.append(resblock(ch, k, d))
|
| 548 |
+
|
| 549 |
+
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
| 550 |
+
self.ups.apply(init_weights)
|
| 551 |
+
|
| 552 |
+
if gin_channels != 0:
|
| 553 |
+
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
| 554 |
+
|
| 555 |
+
def forward(self, x, g=None):
|
| 556 |
+
x = self.conv_pre(x)
|
| 557 |
+
if g is not None:
|
| 558 |
+
x = x + self.cond(g)
|
| 559 |
+
|
| 560 |
+
for i in range(self.num_upsamples):
|
| 561 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
| 562 |
+
x = self.ups[i](x)
|
| 563 |
+
xs = None
|
| 564 |
+
for j in range(self.num_kernels):
|
| 565 |
+
if xs is None:
|
| 566 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
| 567 |
+
else:
|
| 568 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
| 569 |
+
x = xs / self.num_kernels
|
| 570 |
+
x = F.leaky_relu(x)
|
| 571 |
+
x = self.conv_post(x)
|
| 572 |
+
x = torch.tanh(x)
|
| 573 |
+
|
| 574 |
+
return x
|
| 575 |
+
|
| 576 |
+
def remove_weight_norm(self):
|
| 577 |
+
print("Removing weight norm...")
|
| 578 |
+
for layer in self.ups:
|
| 579 |
+
remove_weight_norm(layer)
|
| 580 |
+
for layer in self.resblocks:
|
| 581 |
+
layer.remove_weight_norm()
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
class DiscriminatorP(torch.nn.Module):
|
| 585 |
+
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
| 586 |
+
super(DiscriminatorP, self).__init__()
|
| 587 |
+
self.period = period
|
| 588 |
+
self.use_spectral_norm = use_spectral_norm
|
| 589 |
+
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
| 590 |
+
self.convs = nn.ModuleList(
|
| 591 |
+
[
|
| 592 |
+
norm_f(
|
| 593 |
+
Conv2d(
|
| 594 |
+
1,
|
| 595 |
+
32,
|
| 596 |
+
(kernel_size, 1),
|
| 597 |
+
(stride, 1),
|
| 598 |
+
padding=(get_padding(kernel_size, 1), 0),
|
| 599 |
+
)
|
| 600 |
+
),
|
| 601 |
+
norm_f(
|
| 602 |
+
Conv2d(
|
| 603 |
+
32,
|
| 604 |
+
128,
|
| 605 |
+
(kernel_size, 1),
|
| 606 |
+
(stride, 1),
|
| 607 |
+
padding=(get_padding(kernel_size, 1), 0),
|
| 608 |
+
)
|
| 609 |
+
),
|
| 610 |
+
norm_f(
|
| 611 |
+
Conv2d(
|
| 612 |
+
128,
|
| 613 |
+
512,
|
| 614 |
+
(kernel_size, 1),
|
| 615 |
+
(stride, 1),
|
| 616 |
+
padding=(get_padding(kernel_size, 1), 0),
|
| 617 |
+
)
|
| 618 |
+
),
|
| 619 |
+
norm_f(
|
| 620 |
+
Conv2d(
|
| 621 |
+
512,
|
| 622 |
+
1024,
|
| 623 |
+
(kernel_size, 1),
|
| 624 |
+
(stride, 1),
|
| 625 |
+
padding=(get_padding(kernel_size, 1), 0),
|
| 626 |
+
)
|
| 627 |
+
),
|
| 628 |
+
norm_f(
|
| 629 |
+
Conv2d(
|
| 630 |
+
1024,
|
| 631 |
+
1024,
|
| 632 |
+
(kernel_size, 1),
|
| 633 |
+
1,
|
| 634 |
+
padding=(get_padding(kernel_size, 1), 0),
|
| 635 |
+
)
|
| 636 |
+
),
|
| 637 |
+
]
|
| 638 |
+
)
|
| 639 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
| 640 |
+
|
| 641 |
+
def forward(self, x):
|
| 642 |
+
fmap = []
|
| 643 |
+
|
| 644 |
+
# 1d to 2d
|
| 645 |
+
b, c, t = x.shape
|
| 646 |
+
if t % self.period != 0: # pad first
|
| 647 |
+
n_pad = self.period - (t % self.period)
|
| 648 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
| 649 |
+
t = t + n_pad
|
| 650 |
+
x = x.view(b, c, t // self.period, self.period)
|
| 651 |
+
|
| 652 |
+
for layer in self.convs:
|
| 653 |
+
x = layer(x)
|
| 654 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
| 655 |
+
fmap.append(x)
|
| 656 |
+
x = self.conv_post(x)
|
| 657 |
+
fmap.append(x)
|
| 658 |
+
x = torch.flatten(x, 1, -1)
|
| 659 |
+
|
| 660 |
+
return x, fmap
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
class DiscriminatorS(torch.nn.Module):
|
| 664 |
+
def __init__(self, use_spectral_norm=False):
|
| 665 |
+
super(DiscriminatorS, self).__init__()
|
| 666 |
+
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
| 667 |
+
self.convs = nn.ModuleList(
|
| 668 |
+
[
|
| 669 |
+
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
| 670 |
+
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
| 671 |
+
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
| 672 |
+
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
| 673 |
+
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
| 674 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
| 675 |
+
]
|
| 676 |
+
)
|
| 677 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
| 678 |
+
|
| 679 |
+
def forward(self, x):
|
| 680 |
+
fmap = []
|
| 681 |
+
|
| 682 |
+
for layer in self.convs:
|
| 683 |
+
x = layer(x)
|
| 684 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
| 685 |
+
fmap.append(x)
|
| 686 |
+
x = self.conv_post(x)
|
| 687 |
+
fmap.append(x)
|
| 688 |
+
x = torch.flatten(x, 1, -1)
|
| 689 |
+
|
| 690 |
+
return x, fmap
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
| 694 |
+
def __init__(self, use_spectral_norm=False):
|
| 695 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
| 696 |
+
periods = [2, 3, 5, 7, 11]
|
| 697 |
+
|
| 698 |
+
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
| 699 |
+
discs = discs + [
|
| 700 |
+
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
|
| 701 |
+
]
|
| 702 |
+
self.discriminators = nn.ModuleList(discs)
|
| 703 |
+
|
| 704 |
+
def forward(self, y, y_hat):
|
| 705 |
+
y_d_rs = []
|
| 706 |
+
y_d_gs = []
|
| 707 |
+
fmap_rs = []
|
| 708 |
+
fmap_gs = []
|
| 709 |
+
for i, d in enumerate(self.discriminators):
|
| 710 |
+
y_d_r, fmap_r = d(y)
|
| 711 |
+
y_d_g, fmap_g = d(y_hat)
|
| 712 |
+
y_d_rs.append(y_d_r)
|
| 713 |
+
y_d_gs.append(y_d_g)
|
| 714 |
+
fmap_rs.append(fmap_r)
|
| 715 |
+
fmap_gs.append(fmap_g)
|
| 716 |
+
|
| 717 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
class ReferenceEncoder(nn.Module):
|
| 721 |
+
"""
|
| 722 |
+
inputs --- [N, Ty/r, n_mels*r] mels
|
| 723 |
+
outputs --- [N, ref_enc_gru_size]
|
| 724 |
+
"""
|
| 725 |
+
|
| 726 |
+
def __init__(self, spec_channels, gin_channels=0):
|
| 727 |
+
super().__init__()
|
| 728 |
+
self.spec_channels = spec_channels
|
| 729 |
+
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
| 730 |
+
K = len(ref_enc_filters)
|
| 731 |
+
filters = [1] + ref_enc_filters
|
| 732 |
+
convs = [
|
| 733 |
+
weight_norm(
|
| 734 |
+
nn.Conv2d(
|
| 735 |
+
in_channels=filters[i],
|
| 736 |
+
out_channels=filters[i + 1],
|
| 737 |
+
kernel_size=(3, 3),
|
| 738 |
+
stride=(2, 2),
|
| 739 |
+
padding=(1, 1),
|
| 740 |
+
)
|
| 741 |
+
)
|
| 742 |
+
for i in range(K)
|
| 743 |
+
]
|
| 744 |
+
self.convs = nn.ModuleList(convs)
|
| 745 |
+
# self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501
|
| 746 |
+
|
| 747 |
+
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
|
| 748 |
+
self.gru = nn.GRU(
|
| 749 |
+
input_size=ref_enc_filters[-1] * out_channels,
|
| 750 |
+
hidden_size=256 // 2,
|
| 751 |
+
batch_first=True,
|
| 752 |
+
)
|
| 753 |
+
self.proj = nn.Linear(128, gin_channels)
|
| 754 |
+
|
| 755 |
+
def forward(self, inputs, mask=None):
|
| 756 |
+
N = inputs.size(0)
|
| 757 |
+
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
|
| 758 |
+
for conv in self.convs:
|
| 759 |
+
out = conv(out)
|
| 760 |
+
# out = wn(out)
|
| 761 |
+
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
|
| 762 |
+
|
| 763 |
+
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
|
| 764 |
+
T = out.size(1)
|
| 765 |
+
N = out.size(0)
|
| 766 |
+
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
|
| 767 |
+
|
| 768 |
+
self.gru.flatten_parameters()
|
| 769 |
+
memory, out = self.gru(out) # out --- [1, N, 128]
|
| 770 |
+
|
| 771 |
+
return self.proj(out.squeeze(0))
|
| 772 |
+
|
| 773 |
+
def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
|
| 774 |
+
for i in range(n_convs):
|
| 775 |
+
L = (L - kernel_size + 2 * pad) // stride + 1
|
| 776 |
+
return L
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
class SynthesizerTrn(nn.Module):
|
| 780 |
+
"""
|
| 781 |
+
Synthesizer for Training
|
| 782 |
+
"""
|
| 783 |
+
|
| 784 |
+
def __init__(
|
| 785 |
+
self,
|
| 786 |
+
n_vocab,
|
| 787 |
+
spec_channels,
|
| 788 |
+
segment_size,
|
| 789 |
+
inter_channels,
|
| 790 |
+
hidden_channels,
|
| 791 |
+
filter_channels,
|
| 792 |
+
n_heads,
|
| 793 |
+
n_layers,
|
| 794 |
+
kernel_size,
|
| 795 |
+
p_dropout,
|
| 796 |
+
resblock,
|
| 797 |
+
resblock_kernel_sizes,
|
| 798 |
+
resblock_dilation_sizes,
|
| 799 |
+
upsample_rates,
|
| 800 |
+
upsample_initial_channel,
|
| 801 |
+
upsample_kernel_sizes,
|
| 802 |
+
n_speakers=256,
|
| 803 |
+
gin_channels=256,
|
| 804 |
+
use_sdp=True,
|
| 805 |
+
n_flow_layer=4,
|
| 806 |
+
n_layers_trans_flow=4,
|
| 807 |
+
flow_share_parameter=False,
|
| 808 |
+
use_transformer_flow=True,
|
| 809 |
+
**kwargs,
|
| 810 |
+
):
|
| 811 |
+
super().__init__()
|
| 812 |
+
self.n_vocab = n_vocab
|
| 813 |
+
self.spec_channels = spec_channels
|
| 814 |
+
self.inter_channels = inter_channels
|
| 815 |
+
self.hidden_channels = hidden_channels
|
| 816 |
+
self.filter_channels = filter_channels
|
| 817 |
+
self.n_heads = n_heads
|
| 818 |
+
self.n_layers = n_layers
|
| 819 |
+
self.kernel_size = kernel_size
|
| 820 |
+
self.p_dropout = p_dropout
|
| 821 |
+
self.resblock = resblock
|
| 822 |
+
self.resblock_kernel_sizes = resblock_kernel_sizes
|
| 823 |
+
self.resblock_dilation_sizes = resblock_dilation_sizes
|
| 824 |
+
self.upsample_rates = upsample_rates
|
| 825 |
+
self.upsample_initial_channel = upsample_initial_channel
|
| 826 |
+
self.upsample_kernel_sizes = upsample_kernel_sizes
|
| 827 |
+
self.segment_size = segment_size
|
| 828 |
+
self.n_speakers = n_speakers
|
| 829 |
+
self.gin_channels = gin_channels
|
| 830 |
+
self.n_layers_trans_flow = n_layers_trans_flow
|
| 831 |
+
self.use_spk_conditioned_encoder = kwargs.get(
|
| 832 |
+
"use_spk_conditioned_encoder", True
|
| 833 |
+
)
|
| 834 |
+
self.use_sdp = use_sdp
|
| 835 |
+
self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
|
| 836 |
+
self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
|
| 837 |
+
self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
|
| 838 |
+
self.current_mas_noise_scale = self.mas_noise_scale_initial
|
| 839 |
+
if self.use_spk_conditioned_encoder and gin_channels > 0:
|
| 840 |
+
self.enc_gin_channels = gin_channels
|
| 841 |
+
self.enc_p = TextEncoder(
|
| 842 |
+
n_vocab,
|
| 843 |
+
inter_channels,
|
| 844 |
+
hidden_channels,
|
| 845 |
+
filter_channels,
|
| 846 |
+
n_heads,
|
| 847 |
+
n_layers,
|
| 848 |
+
kernel_size,
|
| 849 |
+
p_dropout,
|
| 850 |
+
self.n_speakers,
|
| 851 |
+
gin_channels=self.enc_gin_channels,
|
| 852 |
+
)
|
| 853 |
+
self.dec = Generator(
|
| 854 |
+
inter_channels,
|
| 855 |
+
resblock,
|
| 856 |
+
resblock_kernel_sizes,
|
| 857 |
+
resblock_dilation_sizes,
|
| 858 |
+
upsample_rates,
|
| 859 |
+
upsample_initial_channel,
|
| 860 |
+
upsample_kernel_sizes,
|
| 861 |
+
gin_channels=gin_channels,
|
| 862 |
+
)
|
| 863 |
+
self.enc_q = PosteriorEncoder(
|
| 864 |
+
spec_channels,
|
| 865 |
+
inter_channels,
|
| 866 |
+
hidden_channels,
|
| 867 |
+
5,
|
| 868 |
+
1,
|
| 869 |
+
16,
|
| 870 |
+
gin_channels=gin_channels,
|
| 871 |
+
)
|
| 872 |
+
if use_transformer_flow:
|
| 873 |
+
self.flow = TransformerCouplingBlock(
|
| 874 |
+
inter_channels,
|
| 875 |
+
hidden_channels,
|
| 876 |
+
filter_channels,
|
| 877 |
+
n_heads,
|
| 878 |
+
n_layers_trans_flow,
|
| 879 |
+
5,
|
| 880 |
+
p_dropout,
|
| 881 |
+
n_flow_layer,
|
| 882 |
+
gin_channels=gin_channels,
|
| 883 |
+
share_parameter=flow_share_parameter,
|
| 884 |
+
)
|
| 885 |
+
else:
|
| 886 |
+
self.flow = ResidualCouplingBlock(
|
| 887 |
+
inter_channels,
|
| 888 |
+
hidden_channels,
|
| 889 |
+
5,
|
| 890 |
+
1,
|
| 891 |
+
n_flow_layer,
|
| 892 |
+
gin_channels=gin_channels,
|
| 893 |
+
)
|
| 894 |
+
self.sdp = StochasticDurationPredictor(
|
| 895 |
+
hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
|
| 896 |
+
)
|
| 897 |
+
self.dp = DurationPredictor(
|
| 898 |
+
hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
|
| 899 |
+
)
|
| 900 |
+
|
| 901 |
+
if n_speakers >= 1:
|
| 902 |
+
self.emb_g = nn.Embedding(n_speakers, gin_channels)
|
| 903 |
+
else:
|
| 904 |
+
self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
|
| 905 |
+
|
| 906 |
+
def export_onnx(
|
| 907 |
+
self,
|
| 908 |
+
path,
|
| 909 |
+
max_len=None,
|
| 910 |
+
sdp_ratio=0,
|
| 911 |
+
y=None,
|
| 912 |
+
):
|
| 913 |
+
noise_scale = 0.667
|
| 914 |
+
length_scale = 1
|
| 915 |
+
noise_scale_w = 0.8
|
| 916 |
+
x = (
|
| 917 |
+
torch.LongTensor(
|
| 918 |
+
[
|
| 919 |
+
0,
|
| 920 |
+
97,
|
| 921 |
+
0,
|
| 922 |
+
8,
|
| 923 |
+
0,
|
| 924 |
+
78,
|
| 925 |
+
0,
|
| 926 |
+
8,
|
| 927 |
+
0,
|
| 928 |
+
76,
|
| 929 |
+
0,
|
| 930 |
+
37,
|
| 931 |
+
0,
|
| 932 |
+
40,
|
| 933 |
+
0,
|
| 934 |
+
97,
|
| 935 |
+
0,
|
| 936 |
+
8,
|
| 937 |
+
0,
|
| 938 |
+
23,
|
| 939 |
+
0,
|
| 940 |
+
8,
|
| 941 |
+
0,
|
| 942 |
+
74,
|
| 943 |
+
0,
|
| 944 |
+
26,
|
| 945 |
+
0,
|
| 946 |
+
104,
|
| 947 |
+
0,
|
| 948 |
+
]
|
| 949 |
+
)
|
| 950 |
+
.unsqueeze(0)
|
| 951 |
+
.cpu()
|
| 952 |
+
)
|
| 953 |
+
tone = torch.zeros_like(x).cpu()
|
| 954 |
+
language = torch.zeros_like(x).cpu()
|
| 955 |
+
x_lengths = torch.LongTensor([x.shape[1]]).cpu()
|
| 956 |
+
sid = torch.LongTensor([0]).cpu()
|
| 957 |
+
bert = torch.randn(size=(x.shape[1], 1024)).cpu()
|
| 958 |
+
ja_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
|
| 959 |
+
en_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
|
| 960 |
+
|
| 961 |
+
if self.n_speakers > 0:
|
| 962 |
+
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
| 963 |
+
torch.onnx.export(
|
| 964 |
+
self.emb_g,
|
| 965 |
+
(sid),
|
| 966 |
+
f"onnx/{path}/{path}_emb.onnx",
|
| 967 |
+
input_names=["sid"],
|
| 968 |
+
output_names=["g"],
|
| 969 |
+
verbose=True,
|
| 970 |
+
)
|
| 971 |
+
else:
|
| 972 |
+
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
| 973 |
+
|
| 974 |
+
emo = torch.randn(512, 1)
|
| 975 |
+
|
| 976 |
+
torch.onnx.export(
|
| 977 |
+
self.enc_p,
|
| 978 |
+
(x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g),
|
| 979 |
+
f"onnx/{path}/{path}_enc_p.onnx",
|
| 980 |
+
input_names=[
|
| 981 |
+
"x",
|
| 982 |
+
"x_lengths",
|
| 983 |
+
"t",
|
| 984 |
+
"language",
|
| 985 |
+
"bert_0",
|
| 986 |
+
"bert_1",
|
| 987 |
+
"bert_2",
|
| 988 |
+
"emo",
|
| 989 |
+
"g",
|
| 990 |
+
],
|
| 991 |
+
output_names=["xout", "m_p", "logs_p", "x_mask"],
|
| 992 |
+
dynamic_axes={
|
| 993 |
+
"x": [0, 1],
|
| 994 |
+
"t": [0, 1],
|
| 995 |
+
"language": [0, 1],
|
| 996 |
+
"bert_0": [0],
|
| 997 |
+
"bert_1": [0],
|
| 998 |
+
"bert_2": [0],
|
| 999 |
+
"xout": [0, 2],
|
| 1000 |
+
"m_p": [0, 2],
|
| 1001 |
+
"logs_p": [0, 2],
|
| 1002 |
+
"x_mask": [0, 2],
|
| 1003 |
+
},
|
| 1004 |
+
verbose=True,
|
| 1005 |
+
opset_version=16,
|
| 1006 |
+
)
|
| 1007 |
+
|
| 1008 |
+
x, m_p, logs_p, x_mask = self.enc_p(
|
| 1009 |
+
x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g
|
| 1010 |
+
)
|
| 1011 |
+
|
| 1012 |
+
zinput = (
|
| 1013 |
+
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
|
| 1014 |
+
* noise_scale_w
|
| 1015 |
+
)
|
| 1016 |
+
torch.onnx.export(
|
| 1017 |
+
self.sdp,
|
| 1018 |
+
(x, x_mask, zinput, g),
|
| 1019 |
+
f"onnx/{path}/{path}_sdp.onnx",
|
| 1020 |
+
input_names=["x", "x_mask", "zin", "g"],
|
| 1021 |
+
output_names=["logw"],
|
| 1022 |
+
dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "zin": [0, 2], "logw": [0, 2]},
|
| 1023 |
+
verbose=True,
|
| 1024 |
+
)
|
| 1025 |
+
torch.onnx.export(
|
| 1026 |
+
self.dp,
|
| 1027 |
+
(x, x_mask, g),
|
| 1028 |
+
f"onnx/{path}/{path}_dp.onnx",
|
| 1029 |
+
input_names=["x", "x_mask", "g"],
|
| 1030 |
+
output_names=["logw"],
|
| 1031 |
+
dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "logw": [0, 2]},
|
| 1032 |
+
verbose=True,
|
| 1033 |
+
)
|
| 1034 |
+
logw = self.sdp(x, x_mask, zinput, g=g) * (sdp_ratio) + self.dp(
|
| 1035 |
+
x, x_mask, g=g
|
| 1036 |
+
) * (1 - sdp_ratio)
|
| 1037 |
+
w = torch.exp(logw) * x_mask * length_scale
|
| 1038 |
+
w_ceil = torch.ceil(w)
|
| 1039 |
+
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
| 1040 |
+
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
|
| 1041 |
+
x_mask.dtype
|
| 1042 |
+
)
|
| 1043 |
+
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
| 1044 |
+
attn = commons.generate_path(w_ceil, attn_mask)
|
| 1045 |
+
|
| 1046 |
+
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
|
| 1047 |
+
1, 2
|
| 1048 |
+
) # [b, t', t], [b, t, d] -> [b, d, t']
|
| 1049 |
+
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
|
| 1050 |
+
1, 2
|
| 1051 |
+
) # [b, t', t], [b, t, d] -> [b, d, t']
|
| 1052 |
+
|
| 1053 |
+
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
| 1054 |
+
torch.onnx.export(
|
| 1055 |
+
self.flow,
|
| 1056 |
+
(z_p, y_mask, g),
|
| 1057 |
+
f"onnx/{path}/{path}_flow.onnx",
|
| 1058 |
+
input_names=["z_p", "y_mask", "g"],
|
| 1059 |
+
output_names=["z"],
|
| 1060 |
+
dynamic_axes={"z_p": [0, 2], "y_mask": [0, 2], "z": [0, 2]},
|
| 1061 |
+
verbose=True,
|
| 1062 |
+
)
|
| 1063 |
+
|
| 1064 |
+
z = self.flow(z_p, y_mask, g=g, reverse=True)
|
| 1065 |
+
z_in = (z * y_mask)[:, :, :max_len]
|
| 1066 |
+
|
| 1067 |
+
torch.onnx.export(
|
| 1068 |
+
self.dec,
|
| 1069 |
+
(z_in, g),
|
| 1070 |
+
f"onnx/{path}/{path}_dec.onnx",
|
| 1071 |
+
input_names=["z_in", "g"],
|
| 1072 |
+
output_names=["o"],
|
| 1073 |
+
dynamic_axes={"z_in": [0, 2], "o": [0, 2]},
|
| 1074 |
+
verbose=True,
|
| 1075 |
+
)
|
| 1076 |
+
o = self.dec((z * y_mask)[:, :, :max_len], g=g)
|
onnx_modules/V220/text/__init__.py
ADDED
@@ -0,0 +1 @@
+from .symbols import *
onnx_modules/V220/text/symbols.py
ADDED
@@ -0,0 +1,187 @@
+punctuation = ["!", "?", "…", ",", ".", "'", "-"]
+pu_symbols = punctuation + ["SP", "UNK"]
+pad = "_"
+
+# chinese
+zh_symbols = [
+    "E",
+    "En",
+    "a",
+    "ai",
+    "an",
+    "ang",
+    "ao",
+    "b",
+    "c",
+    "ch",
+    "d",
+    "e",
+    "ei",
+    "en",
+    "eng",
+    "er",
+    "f",
+    "g",
+    "h",
+    "i",
+    "i0",
+    "ia",
+    "ian",
+    "iang",
+    "iao",
+    "ie",
+    "in",
+    "ing",
+    "iong",
+    "ir",
+    "iu",
+    "j",
+    "k",
+    "l",
+    "m",
+    "n",
+    "o",
+    "ong",
+    "ou",
+    "p",
+    "q",
+    "r",
+    "s",
+    "sh",
+    "t",
+    "u",
+    "ua",
+    "uai",
+    "uan",
+    "uang",
+    "ui",
+    "un",
+    "uo",
+    "v",
+    "van",
+    "ve",
+    "vn",
+    "w",
+    "x",
+    "y",
+    "z",
+    "zh",
+    "AA",
+    "EE",
+    "OO",
+]
+num_zh_tones = 6
+
+# japanese
+ja_symbols = [
+    "N",
+    "a",
+    "a:",
+    "b",
+    "by",
+    "ch",
+    "d",
+    "dy",
+    "e",
+    "e:",
+    "f",
+    "g",
+    "gy",
+    "h",
+    "hy",
+    "i",
+    "i:",
+    "j",
+    "k",
+    "ky",
+    "m",
+    "my",
+    "n",
+    "ny",
+    "o",
+    "o:",
+    "p",
+    "py",
+    "q",
+    "r",
+    "ry",
+    "s",
+    "sh",
+    "t",
+    "ts",
+    "ty",
+    "u",
+    "u:",
+    "w",
+    "y",
+    "z",
+    "zy",
+]
+num_ja_tones = 2
+
+# English
+en_symbols = [
+    "aa",
+    "ae",
+    "ah",
+    "ao",
+    "aw",
+    "ay",
+    "b",
+    "ch",
+    "d",
+    "dh",
+    "eh",
+    "er",
+    "ey",
+    "f",
+    "g",
+    "hh",
+    "ih",
+    "iy",
+    "jh",
+    "k",
+    "l",
+    "m",
+    "n",
+    "ng",
+    "ow",
+    "oy",
+    "p",
+    "r",
+    "s",
+    "sh",
+    "t",
+    "th",
+    "uh",
+    "uw",
+    "V",
+    "w",
+    "y",
+    "z",
+    "zh",
+]
+num_en_tones = 4
+
+# combine all symbols
+normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
+symbols = [pad] + normal_symbols + pu_symbols
+sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
+
+# combine all tones
+num_tones = num_zh_tones + num_ja_tones + num_en_tones
+
+# language maps
+language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
+num_languages = len(language_id_map.keys())
+
+language_tone_start_map = {
+    "ZH": 0,
+    "JP": num_zh_tones,
+    "EN": num_zh_tones + num_ja_tones,
+}
+
+if __name__ == "__main__":
+    a = set(zh_symbols)
+    b = set(en_symbols)
+    print(sorted(a & b))
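The table above merges the Chinese, Japanese and English inventories into one sorted symbol list, and tone ids are made unique by adding the per-language offset from `language_tone_start_map`. The helper below is illustrative only; it roughly mirrors what the text front-end does with these tables but is not part of symbols.py.

```python
# Illustrative helper: map one language's phonemes/tones onto the shared tables.
from onnx_modules.V220.text.symbols import (
    symbols,
    language_id_map,
    language_tone_start_map,
)

_symbol_to_id = {s: i for i, s in enumerate(symbols)}


def phones_to_ids(phones, tones, language):
    ids = [_symbol_to_id[p] for p in phones]                           # symbol indices
    tone_ids = [t + language_tone_start_map[language] for t in tones]  # offset tones
    lang_ids = [language_id_map[language]] * len(phones)               # 0=ZH, 1=JP, 2=EN
    return ids, tone_ids, lang_ids


print(phones_to_ids(["sh", "i0", "SP"], [4, 4, 0], "ZH"))
```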
onnx_modules/V220_OnnxInference/__init__.py
ADDED
@@ -0,0 +1,128 @@
+import numpy as np
+import onnxruntime as ort
+
+
+def convert_pad_shape(pad_shape):
+    layer = pad_shape[::-1]
+    pad_shape = [item for sublist in layer for item in sublist]
+    return pad_shape
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = np.arange(max_length, dtype=length.dtype)
+    return np.expand_dims(x, 0) < np.expand_dims(length, 1)
+
+
+def generate_path(duration, mask):
+    """
+    duration: [b, 1, t_x]
+    mask: [b, 1, t_y, t_x]
+    """
+
+    b, _, t_y, t_x = mask.shape
+    cum_duration = np.cumsum(duration, -1)
+
+    cum_duration_flat = cum_duration.reshape(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y)
+    path = path.reshape(b, t_x, t_y)
+    path = path ^ np.pad(path, ((0, 0), (1, 0), (0, 0)))[:, :-1]
+    path = np.expand_dims(path, 1).transpose(0, 1, 3, 2)
+    return path
+
+
+class OnnxInferenceSession:
+    def __init__(self, path, Providers=["CPUExecutionProvider"]):
+        self.enc = ort.InferenceSession(path["enc"], providers=Providers)
+        self.emb_g = ort.InferenceSession(path["emb_g"], providers=Providers)
+        self.dp = ort.InferenceSession(path["dp"], providers=Providers)
+        self.sdp = ort.InferenceSession(path["sdp"], providers=Providers)
+        self.flow = ort.InferenceSession(path["flow"], providers=Providers)
+        self.dec = ort.InferenceSession(path["dec"], providers=Providers)
+
+    def __call__(
+        self,
+        seq,
+        tone,
+        language,
+        bert_zh,
+        bert_jp,
+        bert_en,
+        emo,
+        sid,
+        seed=114514,
+        seq_noise_scale=0.8,
+        sdp_noise_scale=0.6,
+        length_scale=1.0,
+        sdp_ratio=0.0,
+    ):
+        if seq.ndim == 1:
+            seq = np.expand_dims(seq, 0)
+        if tone.ndim == 1:
+            tone = np.expand_dims(tone, 0)
+        if language.ndim == 1:
+            language = np.expand_dims(language, 0)
+        assert seq.ndim == 2 and tone.ndim == 2 and language.ndim == 2  # each [1, T]
+        g = self.emb_g.run(
+            None,
+            {
+                "sid": sid.astype(np.int64),
+            },
+        )[0]
+        g = np.expand_dims(g, -1)
+        enc_rtn = self.enc.run(
+            None,
+            {
+                "x": seq.astype(np.int64),
+                "t": tone.astype(np.int64),
+                "language": language.astype(np.int64),
+                "bert_0": bert_zh.astype(np.float32),
+                "bert_1": bert_jp.astype(np.float32),
+                "bert_2": bert_en.astype(np.float32),
+                "emo": emo.astype(np.float32),
+                "g": g.astype(np.float32),
+            },
+        )
+        x, m_p, logs_p, x_mask = enc_rtn[0], enc_rtn[1], enc_rtn[2], enc_rtn[3]
+        np.random.seed(seed)
+        zinput = np.random.randn(x.shape[0], 2, x.shape[2]) * sdp_noise_scale
+        logw = self.sdp.run(
+            None, {"x": x, "x_mask": x_mask, "zin": zinput.astype(np.float32), "g": g}
+        )[0] * (sdp_ratio) + self.dp.run(None, {"x": x, "x_mask": x_mask, "g": g})[
+            0
+        ] * (
+            1 - sdp_ratio
+        )
+        w = np.exp(logw) * x_mask * length_scale
+        w_ceil = np.ceil(w)
+        y_lengths = np.clip(np.sum(w_ceil, (1, 2)), a_min=1.0, a_max=100000).astype(
+            np.int64
+        )
+        y_mask = np.expand_dims(sequence_mask(y_lengths, None), 1)
+        attn_mask = np.expand_dims(x_mask, 2) * np.expand_dims(y_mask, -1)
+        attn = generate_path(w_ceil, attn_mask)
+        m_p = np.matmul(attn.squeeze(1), m_p.transpose(0, 2, 1)).transpose(
+            0, 2, 1
+        )  # [b, t', t], [b, t, d] -> [b, d, t']
+        logs_p = np.matmul(attn.squeeze(1), logs_p.transpose(0, 2, 1)).transpose(
+            0, 2, 1
+        )  # [b, t', t], [b, t, d] -> [b, d, t']
+
+        z_p = (
+            m_p
+            + np.random.randn(m_p.shape[0], m_p.shape[1], m_p.shape[2])
+            * np.exp(logs_p)
+            * seq_noise_scale
+        )
+
+        z = self.flow.run(
+            None,
+            {
+                "z_p": z_p.astype(np.float32),
+                "y_mask": y_mask.astype(np.float32),
+                "g": g,
+            },
+        )[0]
+
+        return self.dec.run(None, {"z_in": z.astype(np.float32), "g": g})[0]
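A minimal usage sketch for the session defined above. The file names follow the `{name}_*.onnx` scheme used by `export_onnx`; the model name, paths and the all-zero dummy inputs are assumptions for illustration, not part of this commit.

```python
# Sketch only: run the six exported graphs end to end on dummy inputs.
import numpy as np
from onnx_modules.V220_OnnxInference import OnnxInferenceSession

session = OnnxInferenceSession(
    {
        "enc": "onnx/MyModel/MyModel_enc_p.onnx",
        "emb_g": "onnx/MyModel/MyModel_emb.onnx",
        "dp": "onnx/MyModel/MyModel_dp.onnx",
        "sdp": "onnx/MyModel/MyModel_sdp.onnx",
        "flow": "onnx/MyModel/MyModel_flow.onnx",
        "dec": "onnx/MyModel/MyModel_dec.onnx",
    },
    Providers=["CPUExecutionProvider"],
)

seq_len = 10  # number of phonemes, including padding symbols
audio = session(
    seq=np.zeros(seq_len, dtype=np.int64),                # phoneme ids
    tone=np.zeros(seq_len, dtype=np.int64),               # tone ids (with language offsets)
    language=np.zeros(seq_len, dtype=np.int64),           # language ids
    bert_zh=np.zeros((seq_len, 1024), dtype=np.float32),  # per-phoneme BERT features
    bert_jp=np.zeros((seq_len, 1024), dtype=np.float32),
    bert_en=np.zeros((seq_len, 1024), dtype=np.float32),
    emo=np.zeros((512, 1), dtype=np.float32),             # emotion embedding
    sid=np.array([0], dtype=np.int64),                    # speaker id
)
print(audio.shape)  # (1, 1, n_samples)
```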
onnx_modules/V220_novq_dev/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .text.symbols import symbols
+from .models_onnx import SynthesizerTrn
+
+__all__ = ["symbols", "SynthesizerTrn"]
onnx_modules/V220_novq_dev/attentions_onnx.py
ADDED
|
@@ -0,0 +1,378 @@
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
|
| 6 |
+
import commons
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LayerNorm(nn.Module):
|
| 13 |
+
def __init__(self, channels, eps=1e-5):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.channels = channels
|
| 16 |
+
self.eps = eps
|
| 17 |
+
|
| 18 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
| 19 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
| 20 |
+
|
| 21 |
+
def forward(self, x):
|
| 22 |
+
x = x.transpose(1, -1)
|
| 23 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
| 24 |
+
return x.transpose(1, -1)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@torch.jit.script
|
| 28 |
+
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
| 29 |
+
n_channels_int = n_channels[0]
|
| 30 |
+
in_act = input_a + input_b
|
| 31 |
+
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
| 32 |
+
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
| 33 |
+
acts = t_act * s_act
|
| 34 |
+
return acts
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Encoder(nn.Module):
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
hidden_channels,
|
| 41 |
+
filter_channels,
|
| 42 |
+
n_heads,
|
| 43 |
+
n_layers,
|
| 44 |
+
kernel_size=1,
|
| 45 |
+
p_dropout=0.0,
|
| 46 |
+
window_size=4,
|
| 47 |
+
isflow=True,
|
| 48 |
+
**kwargs
|
| 49 |
+
):
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.hidden_channels = hidden_channels
|
| 52 |
+
self.filter_channels = filter_channels
|
| 53 |
+
self.n_heads = n_heads
|
| 54 |
+
self.n_layers = n_layers
|
| 55 |
+
self.kernel_size = kernel_size
|
| 56 |
+
self.p_dropout = p_dropout
|
| 57 |
+
self.window_size = window_size
|
| 58 |
+
# if isflow:
|
| 59 |
+
# cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
|
| 60 |
+
# self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
|
| 61 |
+
# self.cond_layer = weight_norm(cond_layer, name='weight')
|
| 62 |
+
# self.gin_channels = 256
|
| 63 |
+
self.cond_layer_idx = self.n_layers
|
| 64 |
+
if "gin_channels" in kwargs:
|
| 65 |
+
self.gin_channels = kwargs["gin_channels"]
|
| 66 |
+
if self.gin_channels != 0:
|
| 67 |
+
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
|
| 68 |
+
# vits2 says 3rd block, so idx is 2 by default
|
| 69 |
+
self.cond_layer_idx = (
|
| 70 |
+
kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
|
| 71 |
+
)
|
| 72 |
+
logger.debug("gin_channels: %s, cond_layer_idx: %s", self.gin_channels, self.cond_layer_idx)
|
| 73 |
+
assert (
|
| 74 |
+
self.cond_layer_idx < self.n_layers
|
| 75 |
+
), "cond_layer_idx should be less than n_layers"
|
| 76 |
+
self.drop = nn.Dropout(p_dropout)
|
| 77 |
+
self.attn_layers = nn.ModuleList()
|
| 78 |
+
self.norm_layers_1 = nn.ModuleList()
|
| 79 |
+
self.ffn_layers = nn.ModuleList()
|
| 80 |
+
self.norm_layers_2 = nn.ModuleList()
|
| 81 |
+
for i in range(self.n_layers):
|
| 82 |
+
self.attn_layers.append(
|
| 83 |
+
MultiHeadAttention(
|
| 84 |
+
hidden_channels,
|
| 85 |
+
hidden_channels,
|
| 86 |
+
n_heads,
|
| 87 |
+
p_dropout=p_dropout,
|
| 88 |
+
window_size=window_size,
|
| 89 |
+
)
|
| 90 |
+
)
|
| 91 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
| 92 |
+
self.ffn_layers.append(
|
| 93 |
+
FFN(
|
| 94 |
+
hidden_channels,
|
| 95 |
+
hidden_channels,
|
| 96 |
+
filter_channels,
|
| 97 |
+
kernel_size,
|
| 98 |
+
p_dropout=p_dropout,
|
| 99 |
+
)
|
| 100 |
+
)
|
| 101 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
| 102 |
+
|
| 103 |
+
def forward(self, x, x_mask, g=None):
|
| 104 |
+
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
| 105 |
+
x = x * x_mask
|
| 106 |
+
for i in range(self.n_layers):
|
| 107 |
+
if i == self.cond_layer_idx and g is not None:
|
| 108 |
+
g = self.spk_emb_linear(g.transpose(1, 2))
|
| 109 |
+
g = g.transpose(1, 2)
|
| 110 |
+
x = x + g
|
| 111 |
+
x = x * x_mask
|
| 112 |
+
y = self.attn_layers[i](x, x, attn_mask)
|
| 113 |
+
y = self.drop(y)
|
| 114 |
+
x = self.norm_layers_1[i](x + y)
|
| 115 |
+
|
| 116 |
+
y = self.ffn_layers[i](x, x_mask)
|
| 117 |
+
y = self.drop(y)
|
| 118 |
+
x = self.norm_layers_2[i](x + y)
|
| 119 |
+
x = x * x_mask
|
| 120 |
+
return x
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class MultiHeadAttention(nn.Module):
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
channels,
|
| 127 |
+
out_channels,
|
| 128 |
+
n_heads,
|
| 129 |
+
p_dropout=0.0,
|
| 130 |
+
window_size=None,
|
| 131 |
+
heads_share=True,
|
| 132 |
+
block_length=None,
|
| 133 |
+
proximal_bias=False,
|
| 134 |
+
proximal_init=False,
|
| 135 |
+
):
|
| 136 |
+
super().__init__()
|
| 137 |
+
assert channels % n_heads == 0
|
| 138 |
+
|
| 139 |
+
self.channels = channels
|
| 140 |
+
self.out_channels = out_channels
|
| 141 |
+
self.n_heads = n_heads
|
| 142 |
+
self.p_dropout = p_dropout
|
| 143 |
+
self.window_size = window_size
|
| 144 |
+
self.heads_share = heads_share
|
| 145 |
+
self.block_length = block_length
|
| 146 |
+
self.proximal_bias = proximal_bias
|
| 147 |
+
self.proximal_init = proximal_init
|
| 148 |
+
self.attn = None
|
| 149 |
+
|
| 150 |
+
self.k_channels = channels // n_heads
|
| 151 |
+
self.conv_q = nn.Conv1d(channels, channels, 1)
|
| 152 |
+
self.conv_k = nn.Conv1d(channels, channels, 1)
|
| 153 |
+
self.conv_v = nn.Conv1d(channels, channels, 1)
|
| 154 |
+
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
| 155 |
+
self.drop = nn.Dropout(p_dropout)
|
| 156 |
+
|
| 157 |
+
if window_size is not None:
|
| 158 |
+
n_heads_rel = 1 if heads_share else n_heads
|
| 159 |
+
rel_stddev = self.k_channels**-0.5
|
| 160 |
+
self.emb_rel_k = nn.Parameter(
|
| 161 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
| 162 |
+
* rel_stddev
|
| 163 |
+
)
|
| 164 |
+
self.emb_rel_v = nn.Parameter(
|
| 165 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
| 166 |
+
* rel_stddev
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
nn.init.xavier_uniform_(self.conv_q.weight)
|
| 170 |
+
nn.init.xavier_uniform_(self.conv_k.weight)
|
| 171 |
+
nn.init.xavier_uniform_(self.conv_v.weight)
|
| 172 |
+
if proximal_init:
|
| 173 |
+
with torch.no_grad():
|
| 174 |
+
self.conv_k.weight.copy_(self.conv_q.weight)
|
| 175 |
+
self.conv_k.bias.copy_(self.conv_q.bias)
|
| 176 |
+
|
| 177 |
+
def forward(self, x, c, attn_mask=None):
|
| 178 |
+
q = self.conv_q(x)
|
| 179 |
+
k = self.conv_k(c)
|
| 180 |
+
v = self.conv_v(c)
|
| 181 |
+
|
| 182 |
+
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
| 183 |
+
|
| 184 |
+
x = self.conv_o(x)
|
| 185 |
+
return x
|
| 186 |
+
|
| 187 |
+
def attention(self, query, key, value, mask=None):
|
| 188 |
+
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
| 189 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
| 190 |
+
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
| 191 |
+
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
| 192 |
+
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
| 193 |
+
|
| 194 |
+
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
| 195 |
+
if self.window_size is not None:
|
| 196 |
+
assert (
|
| 197 |
+
t_s == t_t
|
| 198 |
+
), "Relative attention is only available for self-attention."
|
| 199 |
+
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
| 200 |
+
rel_logits = self._matmul_with_relative_keys(
|
| 201 |
+
query / math.sqrt(self.k_channels), key_relative_embeddings
|
| 202 |
+
)
|
| 203 |
+
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
| 204 |
+
scores = scores + scores_local
|
| 205 |
+
if self.proximal_bias:
|
| 206 |
+
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
| 207 |
+
scores = scores + self._attention_bias_proximal(t_s).to(
|
| 208 |
+
device=scores.device, dtype=scores.dtype
|
| 209 |
+
)
|
| 210 |
+
if mask is not None:
|
| 211 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
| 212 |
+
if self.block_length is not None:
|
| 213 |
+
assert (
|
| 214 |
+
t_s == t_t
|
| 215 |
+
), "Local attention is only available for self-attention."
|
| 216 |
+
block_mask = (
|
| 217 |
+
torch.ones_like(scores)
|
| 218 |
+
.triu(-self.block_length)
|
| 219 |
+
.tril(self.block_length)
|
| 220 |
+
)
|
| 221 |
+
scores = scores.masked_fill(block_mask == 0, -1e4)
|
| 222 |
+
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
| 223 |
+
p_attn = self.drop(p_attn)
|
| 224 |
+
output = torch.matmul(p_attn, value)
|
| 225 |
+
if self.window_size is not None:
|
| 226 |
+
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
| 227 |
+
value_relative_embeddings = self._get_relative_embeddings(
|
| 228 |
+
self.emb_rel_v, t_s
|
| 229 |
+
)
|
| 230 |
+
output = output + self._matmul_with_relative_values(
|
| 231 |
+
relative_weights, value_relative_embeddings
|
| 232 |
+
)
|
| 233 |
+
output = (
|
| 234 |
+
output.transpose(2, 3).contiguous().view(b, d, t_t)
|
| 235 |
+
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
| 236 |
+
return output, p_attn
|
| 237 |
+
|
| 238 |
+
def _matmul_with_relative_values(self, x, y):
|
| 239 |
+
"""
|
| 240 |
+
x: [b, h, l, m]
|
| 241 |
+
y: [h or 1, m, d]
|
| 242 |
+
ret: [b, h, l, d]
|
| 243 |
+
"""
|
| 244 |
+
ret = torch.matmul(x, y.unsqueeze(0))
|
| 245 |
+
return ret
|
| 246 |
+
|
| 247 |
+
def _matmul_with_relative_keys(self, x, y):
|
| 248 |
+
"""
|
| 249 |
+
x: [b, h, l, d]
|
| 250 |
+
y: [h or 1, m, d]
|
| 251 |
+
ret: [b, h, l, m]
|
| 252 |
+
"""
|
| 253 |
+
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
| 254 |
+
return ret
|
| 255 |
+
|
| 256 |
+
def _get_relative_embeddings(self, relative_embeddings, length):
|
| 257 |
+
max_relative_position = 2 * self.window_size + 1
|
| 258 |
+
# Pad first before slice to avoid using cond ops.
|
| 259 |
+
pad_length = max(length - (self.window_size + 1), 0)
|
| 260 |
+
slice_start_position = max((self.window_size + 1) - length, 0)
|
| 261 |
+
slice_end_position = slice_start_position + 2 * length - 1
|
| 262 |
+
if pad_length > 0:
|
| 263 |
+
padded_relative_embeddings = F.pad(
|
| 264 |
+
relative_embeddings,
|
| 265 |
+
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
| 266 |
+
)
|
| 267 |
+
else:
|
| 268 |
+
padded_relative_embeddings = relative_embeddings
|
| 269 |
+
used_relative_embeddings = padded_relative_embeddings[
|
| 270 |
+
:, slice_start_position:slice_end_position
|
| 271 |
+
]
|
| 272 |
+
return used_relative_embeddings
|
| 273 |
+
|
| 274 |
+
def _relative_position_to_absolute_position(self, x):
|
| 275 |
+
"""
|
| 276 |
+
x: [b, h, l, 2*l-1]
|
| 277 |
+
ret: [b, h, l, l]
|
| 278 |
+
"""
|
| 279 |
+
batch, heads, length, _ = x.size()
|
| 280 |
+
# Concat columns of pad to shift from relative to absolute indexing.
|
| 281 |
+
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
|
| 282 |
+
|
| 283 |
+
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
| 284 |
+
x_flat = x.view([batch, heads, length * 2 * length])
|
| 285 |
+
x_flat = F.pad(
|
| 286 |
+
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# Reshape and slice out the padded elements.
|
| 290 |
+
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
|
| 291 |
+
:, :, :length, length - 1 :
|
| 292 |
+
]
|
| 293 |
+
return x_final
|
| 294 |
+
|
| 295 |
+
def _absolute_position_to_relative_position(self, x):
|
| 296 |
+
"""
|
| 297 |
+
x: [b, h, l, l]
|
| 298 |
+
ret: [b, h, l, 2*l-1]
|
| 299 |
+
"""
|
| 300 |
+
batch, heads, length, _ = x.size()
|
| 301 |
+
# padd along column
|
| 302 |
+
x = F.pad(
|
| 303 |
+
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
| 304 |
+
)
|
| 305 |
+
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
| 306 |
+
# add 0's in the beginning that will skew the elements after reshape
|
| 307 |
+
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
| 308 |
+
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
| 309 |
+
return x_final
|
| 310 |
+
|
| 311 |
+
def _attention_bias_proximal(self, length):
|
| 312 |
+
"""Bias for self-attention to encourage attention to close positions.
|
| 313 |
+
Args:
|
| 314 |
+
length: an integer scalar.
|
| 315 |
+
Returns:
|
| 316 |
+
a Tensor with shape [1, 1, length, length]
|
| 317 |
+
"""
|
| 318 |
+
r = torch.arange(length, dtype=torch.float32)
|
| 319 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
| 320 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class FFN(nn.Module):
|
| 324 |
+
def __init__(
|
| 325 |
+
self,
|
| 326 |
+
in_channels,
|
| 327 |
+
out_channels,
|
| 328 |
+
filter_channels,
|
| 329 |
+
kernel_size,
|
| 330 |
+
p_dropout=0.0,
|
| 331 |
+
activation=None,
|
| 332 |
+
causal=False,
|
| 333 |
+
):
|
| 334 |
+
super().__init__()
|
| 335 |
+
self.in_channels = in_channels
|
| 336 |
+
self.out_channels = out_channels
|
| 337 |
+
self.filter_channels = filter_channels
|
| 338 |
+
self.kernel_size = kernel_size
|
| 339 |
+
self.p_dropout = p_dropout
|
| 340 |
+
self.activation = activation
|
| 341 |
+
self.causal = causal
|
| 342 |
+
|
| 343 |
+
if causal:
|
| 344 |
+
self.padding = self._causal_padding
|
| 345 |
+
else:
|
| 346 |
+
self.padding = self._same_padding
|
| 347 |
+
|
| 348 |
+
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
| 349 |
+
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
| 350 |
+
self.drop = nn.Dropout(p_dropout)
|
| 351 |
+
|
| 352 |
+
def forward(self, x, x_mask):
|
| 353 |
+
x = self.conv_1(self.padding(x * x_mask))
|
| 354 |
+
if self.activation == "gelu":
|
| 355 |
+
x = x * torch.sigmoid(1.702 * x)
|
| 356 |
+
else:
|
| 357 |
+
x = torch.relu(x)
|
| 358 |
+
x = self.drop(x)
|
| 359 |
+
x = self.conv_2(self.padding(x * x_mask))
|
| 360 |
+
return x * x_mask
|
| 361 |
+
|
| 362 |
+
def _causal_padding(self, x):
|
| 363 |
+
if self.kernel_size == 1:
|
| 364 |
+
return x
|
| 365 |
+
pad_l = self.kernel_size - 1
|
| 366 |
+
pad_r = 0
|
| 367 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
| 368 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
| 369 |
+
return x
|
| 370 |
+
|
| 371 |
+
def _same_padding(self, x):
|
| 372 |
+
if self.kernel_size == 1:
|
| 373 |
+
return x
|
| 374 |
+
pad_l = (self.kernel_size - 1) // 2
|
| 375 |
+
pad_r = self.kernel_size // 2
|
| 376 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
| 377 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
| 378 |
+
return x
|
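A standalone shape check for the gated tanh/sigmoid activation defined near the top of this file. It assumes the repository root is on `sys.path` so the module's `commons` import resolves; purely illustrative.

```python
# Sketch: the fused gate splits the channel axis in half, applies tanh to one
# half and sigmoid to the other, then multiplies the two halves.
import torch
from onnx_modules.V220_novq_dev.attentions_onnx import fused_add_tanh_sigmoid_multiply

a = torch.randn(2, 8, 50)          # [batch, 2 * n_channels, time]
b = torch.randn(2, 8, 50)
n_channels = torch.IntTensor([4])  # passed as a 1-element tensor, as the WaveNet-style modules do
out = fused_add_tanh_sigmoid_multiply(a, b, n_channels)
print(out.shape)  # torch.Size([2, 4, 50])
```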
onnx_modules/V220_novq_dev/models_onnx.py
ADDED
|
@@ -0,0 +1,1048 @@
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
|
| 6 |
+
import commons
|
| 7 |
+
import modules
|
| 8 |
+
from . import attentions_onnx
|
| 9 |
+
|
| 10 |
+
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
| 11 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
| 12 |
+
from commons import init_weights, get_padding
|
| 13 |
+
from .text import symbols, num_tones, num_languages
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class DurationDiscriminator(nn.Module): # vits2
|
| 17 |
+
def __init__(
|
| 18 |
+
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
|
| 19 |
+
):
|
| 20 |
+
super().__init__()
|
| 21 |
+
|
| 22 |
+
self.in_channels = in_channels
|
| 23 |
+
self.filter_channels = filter_channels
|
| 24 |
+
self.kernel_size = kernel_size
|
| 25 |
+
self.p_dropout = p_dropout
|
| 26 |
+
self.gin_channels = gin_channels
|
| 27 |
+
|
| 28 |
+
self.drop = nn.Dropout(p_dropout)
|
| 29 |
+
self.conv_1 = nn.Conv1d(
|
| 30 |
+
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 31 |
+
)
|
| 32 |
+
self.norm_1 = modules.LayerNorm(filter_channels)
|
| 33 |
+
self.conv_2 = nn.Conv1d(
|
| 34 |
+
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 35 |
+
)
|
| 36 |
+
self.norm_2 = modules.LayerNorm(filter_channels)
|
| 37 |
+
self.dur_proj = nn.Conv1d(1, filter_channels, 1)
|
| 38 |
+
|
| 39 |
+
self.pre_out_conv_1 = nn.Conv1d(
|
| 40 |
+
2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 41 |
+
)
|
| 42 |
+
self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
|
| 43 |
+
self.pre_out_conv_2 = nn.Conv1d(
|
| 44 |
+
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 45 |
+
)
|
| 46 |
+
self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
|
| 47 |
+
|
| 48 |
+
if gin_channels != 0:
|
| 49 |
+
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
| 50 |
+
|
| 51 |
+
self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
|
| 52 |
+
|
| 53 |
+
def forward_probability(self, x, x_mask, dur, g=None):
|
| 54 |
+
dur = self.dur_proj(dur)
|
| 55 |
+
x = torch.cat([x, dur], dim=1)
|
| 56 |
+
x = self.pre_out_conv_1(x * x_mask)
|
| 57 |
+
x = torch.relu(x)
|
| 58 |
+
x = self.pre_out_norm_1(x)
|
| 59 |
+
x = self.drop(x)
|
| 60 |
+
x = self.pre_out_conv_2(x * x_mask)
|
| 61 |
+
x = torch.relu(x)
|
| 62 |
+
x = self.pre_out_norm_2(x)
|
| 63 |
+
x = self.drop(x)
|
| 64 |
+
x = x * x_mask
|
| 65 |
+
x = x.transpose(1, 2)
|
| 66 |
+
output_prob = self.output_layer(x)
|
| 67 |
+
return output_prob
|
| 68 |
+
|
| 69 |
+
def forward(self, x, x_mask, dur_r, dur_hat, g=None):
|
| 70 |
+
x = torch.detach(x)
|
| 71 |
+
if g is not None:
|
| 72 |
+
g = torch.detach(g)
|
| 73 |
+
x = x + self.cond(g)
|
| 74 |
+
x = self.conv_1(x * x_mask)
|
| 75 |
+
x = torch.relu(x)
|
| 76 |
+
x = self.norm_1(x)
|
| 77 |
+
x = self.drop(x)
|
| 78 |
+
x = self.conv_2(x * x_mask)
|
| 79 |
+
x = torch.relu(x)
|
| 80 |
+
x = self.norm_2(x)
|
| 81 |
+
x = self.drop(x)
|
| 82 |
+
|
| 83 |
+
output_probs = []
|
| 84 |
+
for dur in [dur_r, dur_hat]:
|
| 85 |
+
output_prob = self.forward_probability(x, x_mask, dur, g)
|
| 86 |
+
output_probs.append(output_prob)
|
| 87 |
+
|
        return output_probs


class TransformerCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
        share_parameter=False,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()

        self.wn = (
            attentions_onnx.FFT(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=self.gin_channels,
            )
            if share_parameter
            else None
        )

        for i in range(n_flows):
            self.flows.append(
                modules.TransformerCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    n_layers,
                    n_heads,
                    p_dropout,
                    filter_channels,
                    mean_only=True,
                    wn_sharing_parameter=self.wn,
                    gin_channels=self.gin_channels,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=True):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class StochasticDurationPredictor(nn.Module):
    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        filter_channels = in_channels  # it needs to be removed from future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.flows.append(modules.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(modules.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, z, g=None):
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        flows = list(reversed(self.flows))
        flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
        for flow in flows:
            z = flow(z, x_mask, g=x, reverse=True)
        z0, z1 = torch.split(z, [1, 1], 1)
        logw = z0
        return logw


class DurationPredictor(nn.Module):
    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class Bottleneck(nn.Sequential):
    def __init__(self, in_dim, hidden_dim):
        c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
        c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
        super().__init__(*[c_fc1, c_fc2])


class Block(nn.Module):
    def __init__(self, in_dim, hidden_dim) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(in_dim)
        self.mlp = MLP(in_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.mlp(self.norm(x))
        return x


class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
        self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
        self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)

    def forward(self, x: torch.Tensor):
        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        x = self.c_proj(x)
        return x


class TextEncoder(nn.Module):
    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_speakers,
        gin_channels=0,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels
        self.emb = nn.Embedding(len(symbols), hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        self.tone_emb = nn.Embedding(num_tones, hidden_channels)
        nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
        self.language_emb = nn.Embedding(num_languages, hidden_channels)
        nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
        self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        # self.emo_proj = nn.Linear(1024, 1024)
        # self.emo_quantizer = nn.ModuleList()
        # for i in range(0, n_speakers):
        #     self.emo_quantizer.append(
        #         VectorQuantize(
        #             dim=1024,
        #             codebook_size=10,
        #             decay=0.8,
        #             commitment_weight=1.0,
        #             learnable_codebook=True,
        #             ema_update=False,
        #         )
        #     )
        # self.emo_q_proj = nn.Linear(1024, hidden_channels)
        self.n_speakers = n_speakers
        self.emo_proj = nn.Linear(512, hidden_channels)

        self.encoder = attentions_onnx.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(
        self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g=None
    ):
        x_mask = torch.ones_like(x).unsqueeze(0)
        bert_emb = self.bert_proj(bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
        ja_bert_emb = self.ja_bert_proj(ja_bert.transpose(0, 1).unsqueeze(0)).transpose(
            1, 2
        )
        en_bert_emb = self.en_bert_proj(en_bert.transpose(0, 1).unsqueeze(0)).transpose(
            1, 2
        )

        x = (
            self.emb(x)
            + self.tone_emb(tone)
            + self.language_emb(language)
            + bert_emb
            + ja_bert_emb
            + en_bert_emb
            + self.emo_proj(emo)
        ) * math.sqrt(
            self.hidden_channels
        )  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = x_mask.to(x.dtype)

        x = self.encoder(x * x_mask, x_mask, g=g)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask


class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=True):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask


class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for layer in self.ups:
            remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class ReferenceEncoder(nn.Module):
    """
    inputs --- [N, Ty/r, n_mels*r]  mels
    outputs --- [N, ref_enc_gru_size]
    """

    def __init__(self, spec_channels, gin_channels=0):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        convs = [
            weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)
        # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])  # noqa: E501

        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)

    def forward(self, inputs, mask=None):
        N = inputs.size(0)
        out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        for conv in self.convs:
            out = conv(out)
            # out = wn(out)
            out = F.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        T = out.size(1)
        N = out.size(0)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        self.gru.flatten_parameters()
        memory, out = self.gru(out)  # out --- [1, N, 128]

        return self.proj(out.squeeze(0))

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        for i in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L


class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=256,
        gin_channels=256,
        use_sdp=True,
        n_flow_layer=4,
        n_layers_trans_flow=4,
        flow_share_parameter=False,
        use_transformer_flow=True,
        **kwargs,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.n_layers_trans_flow = n_layers_trans_flow
        self.use_spk_conditioned_encoder = kwargs.get(
            "use_spk_conditioned_encoder", True
        )
        self.use_sdp = use_sdp
        self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
        self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
        self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
        self.current_mas_noise_scale = self.mas_noise_scale_initial
        if self.use_spk_conditioned_encoder and gin_channels > 0:
            self.enc_gin_channels = gin_channels
        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            self.n_speakers,
            gin_channels=self.enc_gin_channels,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        if use_transformer_flow:
            self.flow = TransformerCouplingBlock(
                inter_channels,
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers_trans_flow,
                5,
                p_dropout,
                n_flow_layer,
                gin_channels=gin_channels,
                share_parameter=flow_share_parameter,
            )
        else:
            self.flow = ResidualCouplingBlock(
                inter_channels,
                hidden_channels,
                5,
                1,
                n_flow_layer,
                gin_channels=gin_channels,
            )
        self.sdp = StochasticDurationPredictor(
            hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
        )
        self.dp = DurationPredictor(
            hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
        )

        if n_speakers >= 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)
        else:
            self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)

    def export_onnx(
        self,
        path,
        max_len=None,
        sdp_ratio=0,
        y=None,
    ):
        noise_scale = 0.667
        length_scale = 1
        noise_scale_w = 0.8
        x = (
            torch.LongTensor(
                [
                    0,
                    97,
                    0,
                    8,
                    0,
                    78,
                    0,
                    8,
                    0,
                    76,
                    0,
                    37,
                    0,
                    40,
                    0,
                    97,
                    0,
                    8,
                    0,
                    23,
                    0,
                    8,
                    0,
                    74,
                    0,
                    26,
                    0,
                    104,
                    0,
                ]
            )
            .unsqueeze(0)
            .cpu()
        )
        tone = torch.zeros_like(x).cpu()
        language = torch.zeros_like(x).cpu()
        x_lengths = torch.LongTensor([x.shape[1]]).cpu()
        sid = torch.LongTensor([0]).cpu()
        bert = torch.randn(size=(x.shape[1], 1024)).cpu()
        ja_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
        en_bert = torch.randn(size=(x.shape[1], 1024)).cpu()

        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
            torch.onnx.export(
                self.emb_g,
                (sid),
                f"onnx/{path}/{path}_emb.onnx",
                input_names=["sid"],
                output_names=["g"],
                verbose=True,
            )
        else:
            g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)

        emo = torch.randn(512, 1)

        torch.onnx.export(
            self.enc_p,
            (x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g),
            f"onnx/{path}/{path}_enc_p.onnx",
            input_names=[
                "x",
                "x_lengths",
                "t",
                "language",
                "bert_0",
                "bert_1",
                "bert_2",
                "emo",
                "g",
            ],
            output_names=["xout", "m_p", "logs_p", "x_mask"],
            dynamic_axes={
                "x": [0, 1],
                "t": [0, 1],
                "language": [0, 1],
                "bert_0": [0],
                "bert_1": [0],
                "bert_2": [0],
                "xout": [0, 2],
                "m_p": [0, 2],
                "logs_p": [0, 2],
                "x_mask": [0, 2],
            },
            verbose=True,
            opset_version=16,
        )

        x, m_p, logs_p, x_mask = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g
        )

        zinput = (
            torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
            * noise_scale_w
        )
        torch.onnx.export(
            self.sdp,
            (x, x_mask, zinput, g),
            f"onnx/{path}/{path}_sdp.onnx",
            input_names=["x", "x_mask", "zin", "g"],
            output_names=["logw"],
            dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "zin": [0, 2], "logw": [0, 2]},
            verbose=True,
        )
        torch.onnx.export(
            self.dp,
            (x, x_mask, g),
            f"onnx/{path}/{path}_dp.onnx",
            input_names=["x", "x_mask", "g"],
            output_names=["logw"],
            dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "logw": [0, 2]},
            verbose=True,
        )
        logw = self.sdp(x, x_mask, zinput, g=g) * (sdp_ratio) + self.dp(
            x, x_mask, g=g
        ) * (1 - sdp_ratio)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        torch.onnx.export(
            self.flow,
            (z_p, y_mask, g),
            f"onnx/{path}/{path}_flow.onnx",
            input_names=["z_p", "y_mask", "g"],
            output_names=["z"],
            dynamic_axes={"z_p": [0, 2], "y_mask": [0, 2], "z": [0, 2]},
            verbose=True,
        )

        z = self.flow(z_p, y_mask, g=g, reverse=True)
        z_in = (z * y_mask)[:, :, :max_len]

        torch.onnx.export(
            self.dec,
            (z_in, g),
            f"onnx/{path}/{path}_dec.onnx",
            input_names=["z_in", "g"],
            output_names=["o"],
            dynamic_axes={"z_in": [0, 2], "o": [0, 2]},
            verbose=True,
        )
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
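Not part of the committed file, but a natural companion to export_onnx above: a minimal onnxruntime sketch that loads one of the exported sub-graphs (the decoder) and runs it on dummy inputs. The f"onnx/{path}/{path}_dec.onnx" file name and the "z_in"/"g"/"o" tensor names come from the export call; the path value, the 192/256 channel sizes, and the random inputs are placeholders assumed from the usual configuration.

import numpy as np
import onnxruntime as ort

path = "BertVits"  # placeholder: the same name that was passed to export_onnx
sess = ort.InferenceSession(f"onnx/{path}/{path}_dec.onnx")

# Dummy latent frames and speaker embedding; 192 / 256 correspond to the usual
# inter_channels / gin_channels settings and are assumptions here.
z_in = np.random.randn(1, 192, 120).astype(np.float32)
g = np.random.randn(1, 256, 1).astype(np.float32)

(audio,) = sess.run(["o"], {"z_in": z_in, "g": g})
print(audio.shape)  # (1, 1, 120 * total upsample factor)

In the same spirit, the other exported parts chain in the order export_onnx traces them: enc_p produces m_p/logs_p, sdp/dp produce durations, flow inverts z_p into z, and dec renders audio.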
onnx_modules/V220_novq_dev/text/__init__.py
ADDED
|
@@ -0,0 +1 @@
from .symbols import *
onnx_modules/V220_novq_dev/text/symbols.py
ADDED
|
@@ -0,0 +1,187 @@
punctuation = ["!", "?", "…", ",", ".", "'", "-"]
pu_symbols = punctuation + ["SP", "UNK"]
pad = "_"

# chinese
zh_symbols = [
    "E",
    "En",
    "a",
    "ai",
    "an",
    "ang",
    "ao",
    "b",
    "c",
    "ch",
    "d",
    "e",
    "ei",
    "en",
    "eng",
    "er",
    "f",
    "g",
    "h",
    "i",
    "i0",
    "ia",
    "ian",
    "iang",
    "iao",
    "ie",
    "in",
    "ing",
    "iong",
    "ir",
    "iu",
    "j",
    "k",
    "l",
    "m",
    "n",
    "o",
    "ong",
    "ou",
    "p",
    "q",
    "r",
    "s",
    "sh",
    "t",
    "u",
    "ua",
    "uai",
    "uan",
    "uang",
    "ui",
    "un",
    "uo",
    "v",
    "van",
    "ve",
    "vn",
    "w",
    "x",
    "y",
    "z",
    "zh",
    "AA",
    "EE",
    "OO",
]
num_zh_tones = 6

# japanese
ja_symbols = [
    "N",
    "a",
    "a:",
    "b",
    "by",
    "ch",
    "d",
    "dy",
    "e",
    "e:",
    "f",
    "g",
    "gy",
    "h",
    "hy",
    "i",
    "i:",
    "j",
    "k",
    "ky",
    "m",
    "my",
    "n",
    "ny",
    "o",
    "o:",
    "p",
    "py",
    "q",
    "r",
    "ry",
    "s",
    "sh",
    "t",
    "ts",
    "ty",
    "u",
    "u:",
    "w",
    "y",
    "z",
    "zy",
]
num_ja_tones = 2

# English
en_symbols = [
    "aa",
    "ae",
    "ah",
    "ao",
    "aw",
    "ay",
    "b",
    "ch",
    "d",
    "dh",
    "eh",
    "er",
    "ey",
    "f",
    "g",
    "hh",
    "ih",
    "iy",
    "jh",
    "k",
    "l",
    "m",
    "n",
    "ng",
    "ow",
    "oy",
    "p",
    "r",
    "s",
    "sh",
    "t",
    "th",
    "uh",
    "uw",
    "V",
    "w",
    "y",
    "z",
    "zh",
]
num_en_tones = 4

# combine all symbols
normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
symbols = [pad] + normal_symbols + pu_symbols
sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]

# combine all tones
num_tones = num_zh_tones + num_ja_tones + num_en_tones

# language maps
language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
num_languages = len(language_id_map.keys())

language_tone_start_map = {
    "ZH": 0,
    "JP": num_zh_tones,
    "EN": num_zh_tones + num_ja_tones,
}

if __name__ == "__main__":
    a = set(zh_symbols)
    b = set(en_symbols)
    print(sorted(a & b))
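A small usage note (not in the file): tone indices are defined per language but embedded in one shared space, so a language-local tone has to be offset by language_tone_start_map before it reaches the tone embedding. A hypothetical helper to make that concrete; to_global_tone is an illustration, not part of the repo.

from onnx_modules.V220_novq_dev.text.symbols import (
    language_tone_start_map,
    num_tones,
    symbols,
)

def to_global_tone(language: str, local_tone: int) -> int:
    # hypothetical helper: JP tone 1 lands after the 6 ZH tones -> 6 + 1 = 7
    return language_tone_start_map[language] + local_tone

assert to_global_tone("ZH", 0) == 0
assert to_global_tone("JP", 1) == 7
assert to_global_tone("EN", 3) == 11 and num_tones == 12
print(len(symbols))  # pad + sorted phoneme set + punctuation/SP/UNK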
onnx_modules/V230/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
from .text.symbols import symbols
from .models_onnx import SynthesizerTrn

__all__ = ["symbols", "SynthesizerTrn"]
onnx_modules/V230/attentions_onnx.py
ADDED
|
@@ -0,0 +1,378 @@
import math
import torch
from torch import nn
from torch.nn import functional as F

import commons
import logging

logger = logging.getLogger(__name__)


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        # if isflow:
        #     cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
        #     self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
        #     self.cond_layer = weight_norm(cond_layer, name='weight')
        #     self.gin_channels = 256
        self.cond_layer_idx = self.n_layers
        if "gin_channels" in kwargs:
            self.gin_channels = kwargs["gin_channels"]
            if self.gin_channels != 0:
                self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
                # vits2 says 3rd block, so idx is 2 by default
                self.cond_layer_idx = (
                    kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
                )
                logging.debug(self.gin_channels, self.cond_layer_idx)
                assert (
                    self.cond_layer_idx < self.n_layers
                ), "cond_layer_idx should be less than n_layers"
        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, g=None):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            if i == self.cond_layer_idx and g is not None:
                g = self.spk_emb_linear(g.transpose(1, 2))
                g = g.transpose(1, 2)
                x = x + g
                x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # padd along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        if causal:
            self.padding = self._causal_padding
        else:
            self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x
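Again not part of the diff: a quick shape check for the MultiHeadAttention block defined above, assuming the file is imported from the repo root so its "import commons" resolves. window_size is left at its None default here, which skips the relative-position branch entirely.

import torch
from onnx_modules.V230.attentions_onnx import MultiHeadAttention

attn = MultiHeadAttention(channels=192, out_channels=192, n_heads=2)
x = torch.randn(1, 192, 37)        # [batch, channels, time]
mask = torch.ones(1, 1, 37, 37)    # full self-attention mask
y = attn(x, x, attn_mask=mask)
print(y.shape)                     # torch.Size([1, 192, 37])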
onnx_modules/V230/models_onnx.py
ADDED
|
@@ -0,0 +1,1061 @@
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
|
| 6 |
+
import commons
|
| 7 |
+
import modules
|
| 8 |
+
from . import attentions_onnx
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
| 12 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
| 13 |
+
|
| 14 |
+
from commons import init_weights, get_padding
|
| 15 |
+
from .text import symbols, num_tones, num_languages
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DurationDiscriminator(nn.Module): # vits2
|
| 21 |
+
def __init__(
|
| 22 |
+
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
|
| 23 |
+
):
|
| 24 |
+
super().__init__()
|
| 25 |
+
|
| 26 |
+
self.in_channels = in_channels
|
| 27 |
+
self.filter_channels = filter_channels
|
| 28 |
+
self.kernel_size = kernel_size
|
| 29 |
+
self.p_dropout = p_dropout
|
| 30 |
+
self.gin_channels = gin_channels
|
| 31 |
+
|
| 32 |
+
self.drop = nn.Dropout(p_dropout)
|
| 33 |
+
self.conv_1 = nn.Conv1d(
|
| 34 |
+
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 35 |
+
)
|
| 36 |
+
self.norm_1 = modules.LayerNorm(filter_channels)
|
| 37 |
+
self.conv_2 = nn.Conv1d(
|
| 38 |
+
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 39 |
+
)
|
| 40 |
+
self.norm_2 = modules.LayerNorm(filter_channels)
|
| 41 |
+
self.dur_proj = nn.Conv1d(1, filter_channels, 1)
|
| 42 |
+
|
| 43 |
+
self.LSTM = nn.LSTM(
|
| 44 |
+
2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
if gin_channels != 0:
|
| 48 |
+
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
| 49 |
+
|
| 50 |
+
self.output_layer = nn.Sequential(
|
| 51 |
+
nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def forward_probability(self, x, dur):
|
| 55 |
+
dur = self.dur_proj(dur)
|
| 56 |
+
x = torch.cat([x, dur], dim=1)
|
| 57 |
+
x = x.transpose(1, 2)
|
| 58 |
+
x, _ = self.LSTM(x)
|
| 59 |
+
output_prob = self.output_layer(x)
|
| 60 |
+
return output_prob
|
| 61 |
+
|
| 62 |
+
def forward(self, x, x_mask, dur_r, dur_hat, g=None):
|
| 63 |
+
x = torch.detach(x)
|
| 64 |
+
if g is not None:
|
| 65 |
+
g = torch.detach(g)
|
| 66 |
+
x = x + self.cond(g)
|
| 67 |
+
x = self.conv_1(x * x_mask)
|
| 68 |
+
x = torch.relu(x)
|
| 69 |
+
x = self.norm_1(x)
|
| 70 |
+
x = self.drop(x)
|
| 71 |
+
x = self.conv_2(x * x_mask)
|
| 72 |
+
x = torch.relu(x)
|
| 73 |
+
x = self.norm_2(x)
|
| 74 |
+
x = self.drop(x)
|
| 75 |
+
|
| 76 |
+
output_probs = []
|
| 77 |
+
for dur in [dur_r, dur_hat]:
|
| 78 |
+
output_prob = self.forward_probability(x, dur)
|
| 79 |
+
output_probs.append(output_prob)
|
| 80 |
+
|
| 81 |
+
return output_probs
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class TransformerCouplingBlock(nn.Module):
|
| 85 |
+
def __init__(
|
| 86 |
+
self,
|
| 87 |
+
channels,
|
| 88 |
+
hidden_channels,
|
| 89 |
+
filter_channels,
|
| 90 |
+
n_heads,
|
| 91 |
+
n_layers,
|
| 92 |
+
kernel_size,
|
| 93 |
+
p_dropout,
|
| 94 |
+
n_flows=4,
|
| 95 |
+
gin_channels=0,
|
| 96 |
+
share_parameter=False,
|
| 97 |
+
):
|
| 98 |
+
super().__init__()
|
| 99 |
+
self.channels = channels
|
| 100 |
+
self.hidden_channels = hidden_channels
|
| 101 |
+
self.kernel_size = kernel_size
|
| 102 |
+
self.n_layers = n_layers
|
| 103 |
+
self.n_flows = n_flows
|
| 104 |
+
self.gin_channels = gin_channels
|
| 105 |
+
|
| 106 |
+
self.flows = nn.ModuleList()
|
| 107 |
+
|
| 108 |
+
self.wn = (
|
| 109 |
+
attentions_onnx.FFT(
|
| 110 |
+
hidden_channels,
|
| 111 |
+
filter_channels,
|
| 112 |
+
n_heads,
|
| 113 |
+
n_layers,
|
| 114 |
+
kernel_size,
|
| 115 |
+
p_dropout,
|
| 116 |
+
isflow=True,
|
| 117 |
+
gin_channels=self.gin_channels,
|
| 118 |
+
)
|
| 119 |
+
if share_parameter
|
| 120 |
+
else None
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
for i in range(n_flows):
|
| 124 |
+
self.flows.append(
|
| 125 |
+
modules.TransformerCouplingLayer(
|
| 126 |
+
channels,
|
| 127 |
+
hidden_channels,
|
| 128 |
+
kernel_size,
|
| 129 |
+
n_layers,
|
| 130 |
+
n_heads,
|
| 131 |
+
p_dropout,
|
| 132 |
+
filter_channels,
|
| 133 |
+
mean_only=True,
|
| 134 |
+
wn_sharing_parameter=self.wn,
|
| 135 |
+
gin_channels=self.gin_channels,
|
| 136 |
+
)
|
| 137 |
+
)
|
| 138 |
+
self.flows.append(modules.Flip())
|
| 139 |
+
|
| 140 |
+
def forward(self, x, x_mask, g=None, reverse=True):
|
| 141 |
+
if not reverse:
|
| 142 |
+
for flow in self.flows:
|
| 143 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
| 144 |
+
else:
|
| 145 |
+
for flow in reversed(self.flows):
|
| 146 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
| 147 |
+
return x
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class StochasticDurationPredictor(nn.Module):
|
| 151 |
+
def __init__(
|
| 152 |
+
self,
|
| 153 |
+
in_channels,
|
| 154 |
+
filter_channels,
|
| 155 |
+
kernel_size,
|
| 156 |
+
p_dropout,
|
| 157 |
+
n_flows=4,
|
| 158 |
+
gin_channels=0,
|
| 159 |
+
):
|
| 160 |
+
super().__init__()
|
| 161 |
+
filter_channels = in_channels # it needs to be removed from future version.
|
| 162 |
+
self.in_channels = in_channels
|
| 163 |
+
self.filter_channels = filter_channels
|
| 164 |
+
self.kernel_size = kernel_size
|
| 165 |
+
self.p_dropout = p_dropout
|
| 166 |
+
self.n_flows = n_flows
|
| 167 |
+
self.gin_channels = gin_channels
|
| 168 |
+
|
| 169 |
+
self.log_flow = modules.Log()
|
| 170 |
+
self.flows = nn.ModuleList()
|
| 171 |
+
self.flows.append(modules.ElementwiseAffine(2))
|
| 172 |
+
for i in range(n_flows):
|
| 173 |
+
self.flows.append(
|
| 174 |
+
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
| 175 |
+
)
|
| 176 |
+
self.flows.append(modules.Flip())
|
| 177 |
+
|
| 178 |
+
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
| 179 |
+
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
| 180 |
+
self.post_convs = modules.DDSConv(
|
| 181 |
+
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
| 182 |
+
)
|
| 183 |
+
self.post_flows = nn.ModuleList()
|
| 184 |
+
self.post_flows.append(modules.ElementwiseAffine(2))
|
| 185 |
+
for i in range(4):
|
| 186 |
+
self.post_flows.append(
|
| 187 |
+
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
| 188 |
+
)
|
| 189 |
+
self.post_flows.append(modules.Flip())
|
| 190 |
+
|
| 191 |
+
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
| 192 |
+
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
| 193 |
+
self.convs = modules.DDSConv(
|
| 194 |
+
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
| 195 |
+
)
|
| 196 |
+
if gin_channels != 0:
|
| 197 |
+
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
| 198 |
+
|
| 199 |
+
def forward(self, x, x_mask, z, g=None):
|
| 200 |
+
x = torch.detach(x)
|
| 201 |
+
x = self.pre(x)
|
| 202 |
+
if g is not None:
|
| 203 |
+
g = torch.detach(g)
|
| 204 |
+
x = x + self.cond(g)
|
| 205 |
+
x = self.convs(x, x_mask)
|
| 206 |
+
x = self.proj(x) * x_mask
|
| 207 |
+
|
| 208 |
+
flows = list(reversed(self.flows))
|
| 209 |
+
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
| 210 |
+
for flow in flows:
|
| 211 |
+
z = flow(z, x_mask, g=x, reverse=True)
|
| 212 |
+
z0, z1 = torch.split(z, [1, 1], 1)
|
| 213 |
+
logw = z0
|
| 214 |
+
return logw
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class DurationPredictor(nn.Module):
|
| 218 |
+
def __init__(
|
| 219 |
+
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
|
| 220 |
+
):
|
| 221 |
+
super().__init__()
|
| 222 |
+
|
| 223 |
+
self.in_channels = in_channels
|
| 224 |
+
self.filter_channels = filter_channels
|
| 225 |
+
self.kernel_size = kernel_size
|
| 226 |
+
self.p_dropout = p_dropout
|
| 227 |
+
self.gin_channels = gin_channels
|
| 228 |
+
|
| 229 |
+
self.drop = nn.Dropout(p_dropout)
|
| 230 |
+
self.conv_1 = nn.Conv1d(
|
| 231 |
+
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 232 |
+
)
|
| 233 |
+
self.norm_1 = modules.LayerNorm(filter_channels)
|
| 234 |
+
self.conv_2 = nn.Conv1d(
|
| 235 |
+
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
| 236 |
+
)
|
| 237 |
+
self.norm_2 = modules.LayerNorm(filter_channels)
|
| 238 |
+
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
| 239 |
+
|
| 240 |
+
if gin_channels != 0:
|
| 241 |
+
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
| 242 |
+
|
| 243 |
+
def forward(self, x, x_mask, g=None):
|
| 244 |
+
x = torch.detach(x)
|
| 245 |
+
if g is not None:
|
| 246 |
+
g = torch.detach(g)
|
| 247 |
+
x = x + self.cond(g)
|
| 248 |
+
x = self.conv_1(x * x_mask)
|
| 249 |
+
x = torch.relu(x)
|
| 250 |
+
x = self.norm_1(x)
|
| 251 |
+
x = self.drop(x)
|
| 252 |
+
x = self.conv_2(x * x_mask)
|
| 253 |
+
x = torch.relu(x)
|
| 254 |
+
x = self.norm_2(x)
|
| 255 |
+
x = self.drop(x)
|
| 256 |
+
x = self.proj(x * x_mask)
|
| 257 |
+
return x * x_mask
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class Bottleneck(nn.Sequential):
|
| 261 |
+
def __init__(self, in_dim, hidden_dim):
|
| 262 |
+
c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
|
| 263 |
+
c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
|
| 264 |
+
super().__init__(*[c_fc1, c_fc2])
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class Block(nn.Module):
|
| 268 |
+
def __init__(self, in_dim, hidden_dim) -> None:
|
| 269 |
+
super().__init__()
|
| 270 |
+
self.norm = nn.LayerNorm(in_dim)
|
| 271 |
+
self.mlp = MLP(in_dim, hidden_dim)
|
| 272 |
+
|
| 273 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 274 |
+
x = x + self.mlp(self.norm(x))
|
| 275 |
+
return x
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class MLP(nn.Module):
|
| 279 |
+
def __init__(self, in_dim, hidden_dim):
|
| 280 |
+
super().__init__()
|
| 281 |
+
self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
|
| 282 |
+
self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
|
| 283 |
+
self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)
|
| 284 |
+
|
| 285 |
+
def forward(self, x: torch.Tensor):
|
| 286 |
+
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
|
| 287 |
+
x = self.c_proj(x)
|
| 288 |
+
return x
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class TextEncoder(nn.Module):
|
| 292 |
+
def __init__(
|
| 293 |
+
self,
|
| 294 |
+
n_vocab,
|
| 295 |
+
out_channels,
|
| 296 |
+
hidden_channels,
|
| 297 |
+
filter_channels,
|
| 298 |
+
n_heads,
|
| 299 |
+
n_layers,
|
| 300 |
+
kernel_size,
|
| 301 |
+
p_dropout,
|
| 302 |
+
gin_channels=0,
|
| 303 |
+
):
|
| 304 |
+
super().__init__()
|
| 305 |
+
self.n_vocab = n_vocab
|
| 306 |
+
self.out_channels = out_channels
|
| 307 |
+
self.hidden_channels = hidden_channels
|
| 308 |
+
self.filter_channels = filter_channels
|
| 309 |
+
self.n_heads = n_heads
|
| 310 |
+
self.n_layers = n_layers
|
| 311 |
+
self.kernel_size = kernel_size
|
| 312 |
+
self.p_dropout = p_dropout
|
| 313 |
+
self.gin_channels = gin_channels
|
| 314 |
+
self.emb = nn.Embedding(len(symbols), hidden_channels)
|
| 315 |
+
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
| 316 |
+
self.tone_emb = nn.Embedding(num_tones, hidden_channels)
|
| 317 |
+
nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
|
| 318 |
+
self.language_emb = nn.Embedding(num_languages, hidden_channels)
|
| 319 |
+
nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
|
| 320 |
+
self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
| 321 |
+
self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
| 322 |
+
self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
| 323 |
+
|
| 324 |
+
self.encoder = attentions_onnx.Encoder(
|
| 325 |
+
hidden_channels,
|
| 326 |
+
filter_channels,
|
| 327 |
+
n_heads,
|
| 328 |
+
n_layers,
|
| 329 |
+
kernel_size,
|
| 330 |
+
p_dropout,
|
| 331 |
+
gin_channels=self.gin_channels,
|
| 332 |
+
)
|
| 333 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
| 334 |
+
|
| 335 |
+
def forward(self, x, x_lengths, tone, language, bert, ja_bert, en_bert, g=None):
|
| 336 |
+
x_mask = torch.ones_like(x).unsqueeze(0)
|
| 337 |
+
bert_emb = self.bert_proj(bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
|
| 338 |
+
ja_bert_emb = self.ja_bert_proj(ja_bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
|
| 339 |
+
en_bert_emb = self.en_bert_proj(en_bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
|
| 340 |
+
x = (
|
| 341 |
+
self.emb(x)
|
| 342 |
+
+ self.tone_emb(tone)
|
| 343 |
+
+ self.language_emb(language)
|
| 344 |
+
+ bert_emb
|
| 345 |
+
+ ja_bert_emb
|
| 346 |
+
+ en_bert_emb
|
| 347 |
+
) * math.sqrt(
|
| 348 |
+
self.hidden_channels
|
| 349 |
+
) # [b, t, h]
|
| 350 |
+
x = torch.transpose(x, 1, -1) # [b, h, t]
|
| 351 |
+
x_mask = x_mask.to(x.dtype)
|
| 352 |
+
|
| 353 |
+
x = self.encoder(x * x_mask, x_mask, g=g)
|
| 354 |
+
stats = self.proj(x) * x_mask
|
| 355 |
+
|
| 356 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
| 357 |
+
return x, m, logs, x_mask
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class ResidualCouplingBlock(nn.Module):
|
| 361 |
+
def __init__(
|
| 362 |
+
self,
|
| 363 |
+
channels,
|
| 364 |
+
hidden_channels,
|
| 365 |
+
kernel_size,
|
| 366 |
+
dilation_rate,
|
| 367 |
+
n_layers,
|
| 368 |
+
n_flows=4,
|
| 369 |
+
gin_channels=0,
|
| 370 |
+
):
|
| 371 |
+
super().__init__()
|
| 372 |
+
self.channels = channels
|
| 373 |
+
self.hidden_channels = hidden_channels
|
| 374 |
+
self.kernel_size = kernel_size
|
| 375 |
+
self.dilation_rate = dilation_rate
|
| 376 |
+
self.n_layers = n_layers
|
| 377 |
+
self.n_flows = n_flows
|
| 378 |
+
self.gin_channels = gin_channels
|
| 379 |
+
|
| 380 |
+
self.flows = nn.ModuleList()
|
| 381 |
+
for i in range(n_flows):
|
| 382 |
+
self.flows.append(
|
| 383 |
+
modules.ResidualCouplingLayer(
|
| 384 |
+
channels,
|
| 385 |
+
hidden_channels,
|
| 386 |
+
kernel_size,
|
| 387 |
+
dilation_rate,
|
| 388 |
+
n_layers,
|
| 389 |
+
gin_channels=gin_channels,
|
| 390 |
+
mean_only=True,
|
| 391 |
+
)
|
| 392 |
+
)
|
| 393 |
+
self.flows.append(modules.Flip())
|
| 394 |
+
|
| 395 |
+
def forward(self, x, x_mask, g=None, reverse=True):
|
| 396 |
+
if not reverse:
|
| 397 |
+
for flow in self.flows:
|
| 398 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
| 399 |
+
else:
|
| 400 |
+
for flow in reversed(self.flows):
|
| 401 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
| 402 |
+
return x
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class PosteriorEncoder(nn.Module):
|
| 406 |
+
def __init__(
|
| 407 |
+
self,
|
| 408 |
+
in_channels,
|
| 409 |
+
out_channels,
|
| 410 |
+
hidden_channels,
|
| 411 |
+
kernel_size,
|
| 412 |
+
dilation_rate,
|
| 413 |
+
n_layers,
|
| 414 |
+
gin_channels=0,
|
| 415 |
+
):
|
| 416 |
+
super().__init__()
|
| 417 |
+
self.in_channels = in_channels
|
| 418 |
+
self.out_channels = out_channels
|
| 419 |
+
self.hidden_channels = hidden_channels
|
| 420 |
+
self.kernel_size = kernel_size
|
| 421 |
+
self.dilation_rate = dilation_rate
|
| 422 |
+
self.n_layers = n_layers
|
| 423 |
+
self.gin_channels = gin_channels
|
| 424 |
+
|
| 425 |
+
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
| 426 |
+
self.enc = modules.WN(
|
| 427 |
+
hidden_channels,
|
| 428 |
+
kernel_size,
|
| 429 |
+
dilation_rate,
|
| 430 |
+
n_layers,
|
| 431 |
+
gin_channels=gin_channels,
|
| 432 |
+
)
|
| 433 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
| 434 |
+
|
| 435 |
+
def forward(self, x, x_lengths, g=None):
|
| 436 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
| 437 |
+
x.dtype
|
| 438 |
+
)
|
| 439 |
+
x = self.pre(x) * x_mask
|
| 440 |
+
x = self.enc(x, x_mask, g=g)
|
| 441 |
+
stats = self.proj(x) * x_mask
|
| 442 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
| 443 |
+
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
| 444 |
+
return z, m, logs, x_mask
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
class Generator(torch.nn.Module):
|
| 448 |
+
def __init__(
|
| 449 |
+
self,
|
| 450 |
+
initial_channel,
|
| 451 |
+
resblock,
|
| 452 |
+
resblock_kernel_sizes,
|
| 453 |
+
resblock_dilation_sizes,
|
| 454 |
+
upsample_rates,
|
| 455 |
+
upsample_initial_channel,
|
| 456 |
+
upsample_kernel_sizes,
|
| 457 |
+
gin_channels=0,
|
| 458 |
+
):
|
| 459 |
+
super(Generator, self).__init__()
|
| 460 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
| 461 |
+
self.num_upsamples = len(upsample_rates)
|
| 462 |
+
self.conv_pre = Conv1d(
|
| 463 |
+
initial_channel, upsample_initial_channel, 7, 1, padding=3
|
| 464 |
+
)
|
| 465 |
+
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
| 466 |
+
|
| 467 |
+
self.ups = nn.ModuleList()
|
| 468 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
| 469 |
+
self.ups.append(
|
| 470 |
+
weight_norm(
|
| 471 |
+
ConvTranspose1d(
|
| 472 |
+
upsample_initial_channel // (2**i),
|
| 473 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
| 474 |
+
k,
|
| 475 |
+
u,
|
| 476 |
+
padding=(k - u) // 2,
|
| 477 |
+
)
|
| 478 |
+
)
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
self.resblocks = nn.ModuleList()
|
| 482 |
+
for i in range(len(self.ups)):
|
| 483 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
| 484 |
+
for j, (k, d) in enumerate(
|
| 485 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
| 486 |
+
):
|
| 487 |
+
self.resblocks.append(resblock(ch, k, d))
|
| 488 |
+
|
| 489 |
+
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
| 490 |
+
self.ups.apply(init_weights)
|
| 491 |
+
|
| 492 |
+
if gin_channels != 0:
|
| 493 |
+
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
| 494 |
+
|
| 495 |
+
def forward(self, x, g=None):
|
| 496 |
+
x = self.conv_pre(x)
|
| 497 |
+
if g is not None:
|
| 498 |
+
x = x + self.cond(g)
|
| 499 |
+
|
| 500 |
+
for i in range(self.num_upsamples):
|
| 501 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
| 502 |
+
x = self.ups[i](x)
|
| 503 |
+
xs = None
|
| 504 |
+
for j in range(self.num_kernels):
|
| 505 |
+
if xs is None:
|
| 506 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
| 507 |
+
else:
|
| 508 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
| 509 |
+
x = xs / self.num_kernels
|
| 510 |
+
x = F.leaky_relu(x)
|
| 511 |
+
x = self.conv_post(x)
|
| 512 |
+
x = torch.tanh(x)
|
| 513 |
+
|
| 514 |
+
return x
|
| 515 |
+
|
| 516 |
+
def remove_weight_norm(self):
|
| 517 |
+
print("Removing weight norm...")
|
| 518 |
+
for layer in self.ups:
|
| 519 |
+
remove_weight_norm(layer)
|
| 520 |
+
for layer in self.resblocks:
|
| 521 |
+
layer.remove_weight_norm()
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
class DiscriminatorP(torch.nn.Module):
|
| 525 |
+
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
| 526 |
+
super(DiscriminatorP, self).__init__()
|
| 527 |
+
self.period = period
|
| 528 |
+
self.use_spectral_norm = use_spectral_norm
|
| 529 |
+
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
| 530 |
+
self.convs = nn.ModuleList(
|
| 531 |
+
[
|
| 532 |
+
norm_f(
|
| 533 |
+
Conv2d(
|
| 534 |
+
1,
|
| 535 |
+
32,
|
| 536 |
+
(kernel_size, 1),
|
| 537 |
+
(stride, 1),
|
| 538 |
+
padding=(get_padding(kernel_size, 1), 0),
|
| 539 |
+
)
|
| 540 |
+
),
|
| 541 |
+
norm_f(
|
| 542 |
+
Conv2d(
|
| 543 |
+
32,
|
| 544 |
+
128,
|
| 545 |
+
(kernel_size, 1),
|
| 546 |
+
(stride, 1),
|
| 547 |
+
padding=(get_padding(kernel_size, 1), 0),
|
| 548 |
+
)
|
| 549 |
+
),
|
| 550 |
+
norm_f(
|
| 551 |
+
Conv2d(
|
| 552 |
+
128,
|
| 553 |
+
512,
|
| 554 |
+
(kernel_size, 1),
|
| 555 |
+
(stride, 1),
|
| 556 |
+
padding=(get_padding(kernel_size, 1), 0),
|
| 557 |
+
)
|
| 558 |
+
),
|
| 559 |
+
norm_f(
|
| 560 |
+
Conv2d(
|
| 561 |
+
512,
|
| 562 |
+
1024,
|
| 563 |
+
(kernel_size, 1),
|
| 564 |
+
(stride, 1),
|
| 565 |
+
padding=(get_padding(kernel_size, 1), 0),
|
| 566 |
+
)
|
| 567 |
+
),
|
| 568 |
+
norm_f(
|
| 569 |
+
Conv2d(
|
| 570 |
+
1024,
|
| 571 |
+
1024,
|
| 572 |
+
(kernel_size, 1),
|
| 573 |
+
1,
|
| 574 |
+
padding=(get_padding(kernel_size, 1), 0),
|
| 575 |
+
)
|
| 576 |
+
),
|
| 577 |
+
]
|
| 578 |
+
)
|
| 579 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
| 580 |
+
|
| 581 |
+
def forward(self, x):
|
| 582 |
+
fmap = []
|
| 583 |
+
|
| 584 |
+
# 1d to 2d
|
| 585 |
+
b, c, t = x.shape
|
| 586 |
+
if t % self.period != 0: # pad first
|
| 587 |
+
n_pad = self.period - (t % self.period)
|
| 588 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
| 589 |
+
t = t + n_pad
|
| 590 |
+
x = x.view(b, c, t // self.period, self.period)
|
| 591 |
+
|
| 592 |
+
for layer in self.convs:
|
| 593 |
+
x = layer(x)
|
| 594 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
| 595 |
+
fmap.append(x)
|
| 596 |
+
x = self.conv_post(x)
|
| 597 |
+
fmap.append(x)
|
| 598 |
+
x = torch.flatten(x, 1, -1)
|
| 599 |
+
|
| 600 |
+
return x, fmap
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
class DiscriminatorS(torch.nn.Module):
|
| 604 |
+
def __init__(self, use_spectral_norm=False):
|
| 605 |
+
super(DiscriminatorS, self).__init__()
|
| 606 |
+
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
| 607 |
+
self.convs = nn.ModuleList(
|
| 608 |
+
[
|
| 609 |
+
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
| 610 |
+
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
| 611 |
+
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
| 612 |
+
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
| 613 |
+
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
| 614 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
| 615 |
+
]
|
| 616 |
+
)
|
| 617 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
| 618 |
+
|
| 619 |
+
def forward(self, x):
|
| 620 |
+
fmap = []
|
| 621 |
+
|
| 622 |
+
for layer in self.convs:
|
| 623 |
+
x = layer(x)
|
| 624 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
| 625 |
+
fmap.append(x)
|
| 626 |
+
x = self.conv_post(x)
|
| 627 |
+
fmap.append(x)
|
| 628 |
+
x = torch.flatten(x, 1, -1)
|
| 629 |
+
|
| 630 |
+
return x, fmap
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
| 634 |
+
def __init__(self, use_spectral_norm=False):
|
| 635 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
| 636 |
+
periods = [2, 3, 5, 7, 11]
|
| 637 |
+
|
| 638 |
+
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
| 639 |
+
discs = discs + [
|
| 640 |
+
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
|
| 641 |
+
]
|
| 642 |
+
self.discriminators = nn.ModuleList(discs)
|
| 643 |
+
|
| 644 |
+
def forward(self, y, y_hat):
|
| 645 |
+
y_d_rs = []
|
| 646 |
+
y_d_gs = []
|
| 647 |
+
fmap_rs = []
|
| 648 |
+
fmap_gs = []
|
| 649 |
+
for i, d in enumerate(self.discriminators):
|
| 650 |
+
y_d_r, fmap_r = d(y)
|
| 651 |
+
y_d_g, fmap_g = d(y_hat)
|
| 652 |
+
y_d_rs.append(y_d_r)
|
| 653 |
+
y_d_gs.append(y_d_g)
|
| 654 |
+
fmap_rs.append(fmap_r)
|
| 655 |
+
fmap_gs.append(fmap_g)
|
| 656 |
+
|
| 657 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
class WavLMDiscriminator(nn.Module):
|
| 661 |
+
"""docstring for Discriminator."""
|
| 662 |
+
|
| 663 |
+
def __init__(
|
| 664 |
+
self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
|
| 665 |
+
):
|
| 666 |
+
super(WavLMDiscriminator, self).__init__()
|
| 667 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
| 668 |
+
self.pre = norm_f(
|
| 669 |
+
Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
self.convs = nn.ModuleList(
|
| 673 |
+
[
|
| 674 |
+
norm_f(
|
| 675 |
+
nn.Conv1d(
|
| 676 |
+
initial_channel, initial_channel * 2, kernel_size=5, padding=2
|
| 677 |
+
)
|
| 678 |
+
),
|
| 679 |
+
norm_f(
|
| 680 |
+
nn.Conv1d(
|
| 681 |
+
initial_channel * 2,
|
| 682 |
+
initial_channel * 4,
|
| 683 |
+
kernel_size=5,
|
| 684 |
+
padding=2,
|
| 685 |
+
)
|
| 686 |
+
),
|
| 687 |
+
norm_f(
|
| 688 |
+
nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
|
| 689 |
+
),
|
| 690 |
+
]
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
|
| 694 |
+
|
| 695 |
+
def forward(self, x):
|
| 696 |
+
x = self.pre(x)
|
| 697 |
+
|
| 698 |
+
fmap = []
|
| 699 |
+
for l in self.convs:
|
| 700 |
+
x = l(x)
|
| 701 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
| 702 |
+
fmap.append(x)
|
| 703 |
+
x = self.conv_post(x)
|
| 704 |
+
x = torch.flatten(x, 1, -1)
|
| 705 |
+
|
| 706 |
+
return x
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
class ReferenceEncoder(nn.Module):
|
| 710 |
+
"""
|
| 711 |
+
inputs --- [N, Ty/r, n_mels*r] mels
|
| 712 |
+
outputs --- [N, ref_enc_gru_size]
|
| 713 |
+
"""
|
| 714 |
+
|
| 715 |
+
def __init__(self, spec_channels, gin_channels=0):
|
| 716 |
+
super().__init__()
|
| 717 |
+
self.spec_channels = spec_channels
|
| 718 |
+
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
| 719 |
+
K = len(ref_enc_filters)
|
| 720 |
+
filters = [1] + ref_enc_filters
|
| 721 |
+
convs = [
|
| 722 |
+
weight_norm(
|
| 723 |
+
nn.Conv2d(
|
| 724 |
+
in_channels=filters[i],
|
| 725 |
+
out_channels=filters[i + 1],
|
| 726 |
+
kernel_size=(3, 3),
|
| 727 |
+
stride=(2, 2),
|
| 728 |
+
padding=(1, 1),
|
| 729 |
+
)
|
| 730 |
+
)
|
| 731 |
+
for i in range(K)
|
| 732 |
+
]
|
| 733 |
+
self.convs = nn.ModuleList(convs)
|
| 734 |
+
# self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501
|
| 735 |
+
|
| 736 |
+
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
|
| 737 |
+
self.gru = nn.GRU(
|
| 738 |
+
input_size=ref_enc_filters[-1] * out_channels,
|
| 739 |
+
hidden_size=256 // 2,
|
| 740 |
+
batch_first=True,
|
| 741 |
+
)
|
| 742 |
+
self.proj = nn.Linear(128, gin_channels)
|
| 743 |
+
|
| 744 |
+
def forward(self, inputs, mask=None):
|
| 745 |
+
N = inputs.size(0)
|
| 746 |
+
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
|
| 747 |
+
for conv in self.convs:
|
| 748 |
+
out = conv(out)
|
| 749 |
+
# out = wn(out)
|
| 750 |
+
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
|
| 751 |
+
|
| 752 |
+
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
|
| 753 |
+
T = out.size(1)
|
| 754 |
+
N = out.size(0)
|
| 755 |
+
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
|
| 756 |
+
|
| 757 |
+
self.gru.flatten_parameters()
|
| 758 |
+
memory, out = self.gru(out) # out --- [1, N, 128]
|
| 759 |
+
|
| 760 |
+
return self.proj(out.squeeze(0))
|
| 761 |
+
|
| 762 |
+
def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
|
| 763 |
+
for i in range(n_convs):
|
| 764 |
+
L = (L - kernel_size + 2 * pad) // stride + 1
|
| 765 |
+
return L
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
class SynthesizerTrn(nn.Module):
|
| 769 |
+
"""
|
| 770 |
+
Synthesizer for Training
|
| 771 |
+
"""
|
| 772 |
+
|
| 773 |
+
def __init__(
|
| 774 |
+
self,
|
| 775 |
+
n_vocab,
|
| 776 |
+
spec_channels,
|
| 777 |
+
segment_size,
|
| 778 |
+
inter_channels,
|
| 779 |
+
hidden_channels,
|
| 780 |
+
filter_channels,
|
| 781 |
+
n_heads,
|
| 782 |
+
n_layers,
|
| 783 |
+
kernel_size,
|
| 784 |
+
p_dropout,
|
| 785 |
+
resblock,
|
| 786 |
+
resblock_kernel_sizes,
|
| 787 |
+
resblock_dilation_sizes,
|
| 788 |
+
upsample_rates,
|
| 789 |
+
upsample_initial_channel,
|
| 790 |
+
upsample_kernel_sizes,
|
| 791 |
+
n_speakers=256,
|
| 792 |
+
gin_channels=256,
|
| 793 |
+
use_sdp=True,
|
| 794 |
+
n_flow_layer=4,
|
| 795 |
+
n_layers_trans_flow=4,
|
| 796 |
+
flow_share_parameter=False,
|
| 797 |
+
use_transformer_flow=True,
|
| 798 |
+
**kwargs
|
| 799 |
+
):
|
| 800 |
+
super().__init__()
|
| 801 |
+
self.n_vocab = n_vocab
|
| 802 |
+
self.spec_channels = spec_channels
|
| 803 |
+
self.inter_channels = inter_channels
|
| 804 |
+
self.hidden_channels = hidden_channels
|
| 805 |
+
self.filter_channels = filter_channels
|
| 806 |
+
self.n_heads = n_heads
|
| 807 |
+
self.n_layers = n_layers
|
| 808 |
+
self.kernel_size = kernel_size
|
| 809 |
+
self.p_dropout = p_dropout
|
| 810 |
+
self.resblock = resblock
|
| 811 |
+
self.resblock_kernel_sizes = resblock_kernel_sizes
|
| 812 |
+
self.resblock_dilation_sizes = resblock_dilation_sizes
|
| 813 |
+
self.upsample_rates = upsample_rates
|
| 814 |
+
self.upsample_initial_channel = upsample_initial_channel
|
| 815 |
+
self.upsample_kernel_sizes = upsample_kernel_sizes
|
| 816 |
+
self.segment_size = segment_size
|
| 817 |
+
self.n_speakers = n_speakers
|
| 818 |
+
self.gin_channels = gin_channels
|
| 819 |
+
self.n_layers_trans_flow = n_layers_trans_flow
|
| 820 |
+
self.use_spk_conditioned_encoder = kwargs.get(
|
| 821 |
+
"use_spk_conditioned_encoder", True
|
| 822 |
+
)
|
| 823 |
+
self.use_sdp = use_sdp
|
| 824 |
+
self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
|
| 825 |
+
self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
|
| 826 |
+
self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
|
| 827 |
+
self.current_mas_noise_scale = self.mas_noise_scale_initial
|
| 828 |
+
if self.use_spk_conditioned_encoder and gin_channels > 0:
|
| 829 |
+
self.enc_gin_channels = gin_channels
|
| 830 |
+
self.enc_p = TextEncoder(
|
| 831 |
+
n_vocab,
|
| 832 |
+
inter_channels,
|
| 833 |
+
hidden_channels,
|
| 834 |
+
filter_channels,
|
| 835 |
+
n_heads,
|
| 836 |
+
n_layers,
|
| 837 |
+
kernel_size,
|
| 838 |
+
p_dropout,
|
| 839 |
+
gin_channels=self.enc_gin_channels,
|
| 840 |
+
)
|
| 841 |
+
self.dec = Generator(
|
| 842 |
+
inter_channels,
|
| 843 |
+
resblock,
|
| 844 |
+
resblock_kernel_sizes,
|
| 845 |
+
resblock_dilation_sizes,
|
| 846 |
+
upsample_rates,
|
| 847 |
+
upsample_initial_channel,
|
| 848 |
+
upsample_kernel_sizes,
|
| 849 |
+
gin_channels=gin_channels,
|
| 850 |
+
)
|
| 851 |
+
self.enc_q = PosteriorEncoder(
|
| 852 |
+
spec_channels,
|
| 853 |
+
inter_channels,
|
| 854 |
+
hidden_channels,
|
| 855 |
+
5,
|
| 856 |
+
1,
|
| 857 |
+
16,
|
| 858 |
+
gin_channels=gin_channels,
|
| 859 |
+
)
|
| 860 |
+
if use_transformer_flow:
|
| 861 |
+
self.flow = TransformerCouplingBlock(
|
| 862 |
+
inter_channels,
|
| 863 |
+
hidden_channels,
|
| 864 |
+
filter_channels,
|
| 865 |
+
n_heads,
|
| 866 |
+
n_layers_trans_flow,
|
| 867 |
+
5,
|
| 868 |
+
p_dropout,
|
| 869 |
+
n_flow_layer,
|
| 870 |
+
gin_channels=gin_channels,
|
| 871 |
+
share_parameter=flow_share_parameter,
|
| 872 |
+
)
|
| 873 |
+
else:
|
| 874 |
+
self.flow = ResidualCouplingBlock(
|
| 875 |
+
inter_channels,
|
| 876 |
+
hidden_channels,
|
| 877 |
+
5,
|
| 878 |
+
1,
|
| 879 |
+
n_flow_layer,
|
| 880 |
+
gin_channels=gin_channels,
|
| 881 |
+
)
|
| 882 |
+
self.sdp = StochasticDurationPredictor(
|
| 883 |
+
hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
|
| 884 |
+
)
|
| 885 |
+
self.dp = DurationPredictor(
|
| 886 |
+
hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
|
| 887 |
+
)
|
| 888 |
+
|
| 889 |
+
if n_speakers >= 1:
|
| 890 |
+
self.emb_g = nn.Embedding(n_speakers, gin_channels)
|
| 891 |
+
else:
|
| 892 |
+
self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
|
| 893 |
+
|
| 894 |
+
def export_onnx(
|
| 895 |
+
self,
|
| 896 |
+
path,
|
| 897 |
+
max_len=None,
|
| 898 |
+
sdp_ratio=0,
|
| 899 |
+
y=None,
|
| 900 |
+
):
|
| 901 |
+
noise_scale = 0.667
|
| 902 |
+
length_scale = 1
|
| 903 |
+
noise_scale_w = 0.8
|
| 904 |
+
x = (
|
| 905 |
+
torch.LongTensor(
|
| 906 |
+
[
|
| 907 |
+
0,
|
| 908 |
+
97,
|
| 909 |
+
0,
|
| 910 |
+
8,
|
| 911 |
+
0,
|
| 912 |
+
78,
|
| 913 |
+
0,
|
| 914 |
+
8,
|
| 915 |
+
0,
|
| 916 |
+
76,
|
| 917 |
+
0,
|
| 918 |
+
37,
|
| 919 |
+
0,
|
| 920 |
+
40,
|
| 921 |
+
0,
|
| 922 |
+
97,
|
| 923 |
+
0,
|
| 924 |
+
8,
|
| 925 |
+
0,
|
| 926 |
+
23,
|
| 927 |
+
0,
|
| 928 |
+
8,
|
| 929 |
+
0,
|
| 930 |
+
74,
|
| 931 |
+
0,
|
| 932 |
+
26,
|
| 933 |
+
0,
|
| 934 |
+
104,
|
| 935 |
+
0,
|
| 936 |
+
]
|
| 937 |
+
)
|
| 938 |
+
.unsqueeze(0)
|
| 939 |
+
.cpu()
|
| 940 |
+
)
|
| 941 |
+
tone = torch.zeros_like(x).cpu()
|
| 942 |
+
language = torch.zeros_like(x).cpu()
|
| 943 |
+
x_lengths = torch.LongTensor([x.shape[1]]).cpu()
|
| 944 |
+
sid = torch.LongTensor([0]).cpu()
|
| 945 |
+
bert = torch.randn(size=(x.shape[1], 1024)).cpu()
|
| 946 |
+
ja_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
|
| 947 |
+
en_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
|
| 948 |
+
|
| 949 |
+
if self.n_speakers > 0:
|
| 950 |
+
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
| 951 |
+
torch.onnx.export(
|
| 952 |
+
self.emb_g,
|
| 953 |
+
(sid),
|
| 954 |
+
f"onnx/{path}/{path}_emb.onnx",
|
| 955 |
+
input_names=["sid"],
|
| 956 |
+
output_names=["g"],
|
| 957 |
+
verbose=True,
|
| 958 |
+
)
|
| 959 |
+
else:
|
| 960 |
+
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
| 961 |
+
|
| 962 |
+
torch.onnx.export(
|
| 963 |
+
self.enc_p,
|
| 964 |
+
(x, x_lengths, tone, language, bert, ja_bert, en_bert, g),
|
| 965 |
+
f"onnx/{path}/{path}_enc_p.onnx",
|
| 966 |
+
input_names=[
|
| 967 |
+
"x",
|
| 968 |
+
"x_lengths",
|
| 969 |
+
"t",
|
| 970 |
+
"language",
|
| 971 |
+
"bert_0",
|
| 972 |
+
"bert_1",
|
| 973 |
+
"bert_2",
|
| 974 |
+
"g",
|
| 975 |
+
],
|
| 976 |
+
output_names=["xout", "m_p", "logs_p", "x_mask"],
|
| 977 |
+
dynamic_axes={
|
| 978 |
+
"x": [0, 1],
|
| 979 |
+
"t": [0, 1],
|
| 980 |
+
"language": [0, 1],
|
| 981 |
+
"bert_0": [0],
|
| 982 |
+
"bert_1": [0],
|
| 983 |
+
"bert_2": [0],
|
| 984 |
+
"xout": [0, 2],
|
| 985 |
+
"m_p": [0, 2],
|
| 986 |
+
"logs_p": [0, 2],
|
| 987 |
+
"x_mask": [0, 2],
|
| 988 |
+
},
|
| 989 |
+
verbose=True,
|
| 990 |
+
opset_version=16,
|
| 991 |
+
)
|
| 992 |
+
|
| 993 |
+
x, m_p, logs_p, x_mask = self.enc_p(
|
| 994 |
+
x, x_lengths, tone, language, bert, ja_bert, en_bert, g
|
| 995 |
+
)
|
| 996 |
+
|
| 997 |
+
zinput = (
|
| 998 |
+
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
|
| 999 |
+
* noise_scale_w
|
| 1000 |
+
)
|
| 1001 |
+
torch.onnx.export(
|
| 1002 |
+
self.sdp,
|
| 1003 |
+
(x, x_mask, zinput, g),
|
| 1004 |
+
f"onnx/{path}/{path}_sdp.onnx",
|
| 1005 |
+
input_names=["x", "x_mask", "zin", "g"],
|
| 1006 |
+
output_names=["logw"],
|
| 1007 |
+
dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "zin": [0, 2], "logw": [0, 2]},
|
| 1008 |
+
verbose=True,
|
| 1009 |
+
)
|
| 1010 |
+
torch.onnx.export(
|
| 1011 |
+
self.dp,
|
| 1012 |
+
(x, x_mask, g),
|
| 1013 |
+
f"onnx/{path}/{path}_dp.onnx",
|
| 1014 |
+
input_names=["x", "x_mask", "g"],
|
| 1015 |
+
output_names=["logw"],
|
| 1016 |
+
dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "logw": [0, 2]},
|
| 1017 |
+
verbose=True,
|
| 1018 |
+
)
|
| 1019 |
+
logw = self.sdp(x, x_mask, zinput, g=g) * (sdp_ratio) + self.dp(
|
| 1020 |
+
x, x_mask, g=g
|
| 1021 |
+
) * (1 - sdp_ratio)
|
| 1022 |
+
w = torch.exp(logw) * x_mask * length_scale
|
| 1023 |
+
w_ceil = torch.ceil(w)
|
| 1024 |
+
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
| 1025 |
+
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
|
| 1026 |
+
x_mask.dtype
|
| 1027 |
+
)
|
| 1028 |
+
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
| 1029 |
+
attn = commons.generate_path(w_ceil, attn_mask)
|
| 1030 |
+
|
| 1031 |
+
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
|
| 1032 |
+
1, 2
|
| 1033 |
+
) # [b, t', t], [b, t, d] -> [b, d, t']
|
| 1034 |
+
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
|
| 1035 |
+
1, 2
|
| 1036 |
+
) # [b, t', t], [b, t, d] -> [b, d, t']
|
| 1037 |
+
|
| 1038 |
+
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
| 1039 |
+
torch.onnx.export(
|
| 1040 |
+
self.flow,
|
| 1041 |
+
(z_p, y_mask, g),
|
| 1042 |
+
f"onnx/{path}/{path}_flow.onnx",
|
| 1043 |
+
input_names=["z_p", "y_mask", "g"],
|
| 1044 |
+
output_names=["z"],
|
| 1045 |
+
dynamic_axes={"z_p": [0, 2], "y_mask": [0, 2], "z": [0, 2]},
|
| 1046 |
+
verbose=True,
|
| 1047 |
+
)
|
| 1048 |
+
|
| 1049 |
+
z = self.flow(z_p, y_mask, g=g, reverse=True)
|
| 1050 |
+
z_in = (z * y_mask)[:, :, :max_len]
|
| 1051 |
+
|
| 1052 |
+
torch.onnx.export(
|
| 1053 |
+
self.dec,
|
| 1054 |
+
(z_in, g),
|
| 1055 |
+
f"onnx/{path}/{path}_dec.onnx",
|
| 1056 |
+
input_names=["z_in", "g"],
|
| 1057 |
+
output_names=["o"],
|
| 1058 |
+
dynamic_axes={"z_in": [0, 2], "o": [0, 2]},
|
| 1059 |
+
verbose=True,
|
| 1060 |
+
)
|
| 1061 |
+
o = self.dec((z * y_mask)[:, :, :max_len], g=g)
|
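The export_onnx method above writes the synthesizer out as six separate ONNX graphs (emb, enc_p, sdp, dp, flow, dec) under onnx/{path}/. A minimal driver sketch follows; the checkpoint/config paths are placeholders and the load_checkpoint call and hyperparameter plumbing are assumptions mirroring how the repository's wrapper builds the network, not part of this file.

# Sketch only: paths are placeholders; load_checkpoint signature is assumed.
from utils import get_hparams_from_file, load_checkpoint
from onnx_modules.V230 import SynthesizerTrn, symbols

hps = get_hparams_from_file("path/to/config.json")   # placeholder config ("version": "2.3.x")
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model,
)
load_checkpoint("path/to/G_latest.pth", net_g, None, skip_optimizer=True)  # placeholder checkpoint
net_g.cpu().eval()
# The onnx/MyModel/ directory must already exist before export.
net_g.export_onnx("MyModel")  # -> onnx/MyModel/MyModel_{emb,enc_p,sdp,dp,flow,dec}.onnx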
onnx_modules/V230/text/__init__.py ADDED
@@ -0,0 +1 @@
from .symbols import *
onnx_modules/V230/text/symbols.py ADDED
@@ -0,0 +1,187 @@
punctuation = ["!", "?", "…", ",", ".", "'", "-"]
pu_symbols = punctuation + ["SP", "UNK"]
pad = "_"

# chinese
zh_symbols = [
    "E",
    "En",
    "a",
    "ai",
    "an",
    "ang",
    "ao",
    "b",
    "c",
    "ch",
    "d",
    "e",
    "ei",
    "en",
    "eng",
    "er",
    "f",
    "g",
    "h",
    "i",
    "i0",
    "ia",
    "ian",
    "iang",
    "iao",
    "ie",
    "in",
    "ing",
    "iong",
    "ir",
    "iu",
    "j",
    "k",
    "l",
    "m",
    "n",
    "o",
    "ong",
    "ou",
    "p",
    "q",
    "r",
    "s",
    "sh",
    "t",
    "u",
    "ua",
    "uai",
    "uan",
    "uang",
    "ui",
    "un",
    "uo",
    "v",
    "van",
    "ve",
    "vn",
    "w",
    "x",
    "y",
    "z",
    "zh",
    "AA",
    "EE",
    "OO",
]
num_zh_tones = 6

# japanese
ja_symbols = [
    "N",
    "a",
    "a:",
    "b",
    "by",
    "ch",
    "d",
    "dy",
    "e",
    "e:",
    "f",
    "g",
    "gy",
    "h",
    "hy",
    "i",
    "i:",
    "j",
    "k",
    "ky",
    "m",
    "my",
    "n",
    "ny",
    "o",
    "o:",
    "p",
    "py",
    "q",
    "r",
    "ry",
    "s",
    "sh",
    "t",
    "ts",
    "ty",
    "u",
    "u:",
    "w",
    "y",
    "z",
    "zy",
]
num_ja_tones = 2

# English
en_symbols = [
    "aa",
    "ae",
    "ah",
    "ao",
    "aw",
    "ay",
    "b",
    "ch",
    "d",
    "dh",
    "eh",
    "er",
    "ey",
    "f",
    "g",
    "hh",
    "ih",
    "iy",
    "jh",
    "k",
    "l",
    "m",
    "n",
    "ng",
    "ow",
    "oy",
    "p",
    "r",
    "s",
    "sh",
    "t",
    "th",
    "uh",
    "uw",
    "V",
    "w",
    "y",
    "z",
    "zh",
]
num_en_tones = 4

# combine all symbols
normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
symbols = [pad] + normal_symbols + pu_symbols
sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]

# combine all tones
num_tones = num_zh_tones + num_ja_tones + num_en_tones

# language maps
language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
num_languages = len(language_id_map.keys())

language_tone_start_map = {
    "ZH": 0,
    "JP": num_zh_tones,
    "EN": num_zh_tones + num_ja_tones,
}

if __name__ == "__main__":
    a = set(zh_symbols)
    b = set(en_symbols)
    print(sorted(a & b))
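These tables feed the V230 text encoder above (symbols, num_tones, num_languages). A quick inspection sketch, assuming the package is importable from the repository root; the comment about per-language tone offsets reflects how language_tone_start_map is presumably consumed during text preprocessing, which is not shown in this commit.

# Sketch: inspect the shared symbol/tone tables (import path assumes repo root).
from onnx_modules.V230.text.symbols import (
    symbols,
    num_tones,
    num_languages,
    language_tone_start_map,
)

print(len(symbols), num_tones, num_languages)  # vocabulary size, 12 tones (6 ZH + 2 JA + 4 EN), 3 languages
print(language_tone_start_map["JP"])           # 6: JP tone ids presumably start after the ZH tone range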
onnx_modules/V230_OnnxInference/__init__.py ADDED
@@ -0,0 +1,126 @@
import numpy as np
import onnxruntime as ort


def convert_pad_shape(pad_shape):
    layer = pad_shape[::-1]
    pad_shape = [item for sublist in layer for item in sublist]
    return pad_shape


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = np.arange(max_length, dtype=length.dtype)
    return np.expand_dims(x, 0) < np.expand_dims(length, 1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """

    b, _, t_y, t_x = mask.shape
    cum_duration = np.cumsum(duration, -1)

    cum_duration_flat = cum_duration.reshape(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y)
    path = path.reshape(b, t_x, t_y)
    path = path ^ np.pad(path, ((0, 0), (1, 0), (0, 0)))[:, :-1]
    path = np.expand_dims(path, 1).transpose(0, 1, 3, 2)
    return path


class OnnxInferenceSession:
    def __init__(self, path, Providers=["CPUExecutionProvider"]):
        self.enc = ort.InferenceSession(path["enc"], providers=Providers)
        self.emb_g = ort.InferenceSession(path["emb_g"], providers=Providers)
        self.dp = ort.InferenceSession(path["dp"], providers=Providers)
        self.sdp = ort.InferenceSession(path["sdp"], providers=Providers)
        self.flow = ort.InferenceSession(path["flow"], providers=Providers)
        self.dec = ort.InferenceSession(path["dec"], providers=Providers)

    def __call__(
        self,
        seq,
        tone,
        language,
        bert_zh,
        bert_jp,
        bert_en,
        sid,
        seed=114514,
        seq_noise_scale=0.8,
        sdp_noise_scale=0.6,
        length_scale=1.0,
        sdp_ratio=0.0,
    ):
        if seq.ndim == 1:
            seq = np.expand_dims(seq, 0)
        if tone.ndim == 1:
            tone = np.expand_dims(tone, 0)
        if language.ndim == 1:
            language = np.expand_dims(language, 0)
        # Original code asserted a tuple (always truthy); check each condition instead.
        assert seq.ndim == 2 and tone.ndim == 2 and language.ndim == 2
        g = self.emb_g.run(
            None,
            {
                "sid": sid.astype(np.int64),
            },
        )[0]
        g = np.expand_dims(g, -1)
        enc_rtn = self.enc.run(
            None,
            {
                "x": seq.astype(np.int64),
                "t": tone.astype(np.int64),
                "language": language.astype(np.int64),
                "bert_0": bert_zh.astype(np.float32),
                "bert_1": bert_jp.astype(np.float32),
                "bert_2": bert_en.astype(np.float32),
                "g": g.astype(np.float32),
            },
        )
        x, m_p, logs_p, x_mask = enc_rtn[0], enc_rtn[1], enc_rtn[2], enc_rtn[3]
        np.random.seed(seed)
        zinput = np.random.randn(x.shape[0], 2, x.shape[2]) * sdp_noise_scale
        logw = self.sdp.run(
            None, {"x": x, "x_mask": x_mask, "zin": zinput.astype(np.float32), "g": g}
        )[0] * (sdp_ratio) + self.dp.run(None, {"x": x, "x_mask": x_mask, "g": g})[
            0
        ] * (
            1 - sdp_ratio
        )
        w = np.exp(logw) * x_mask * length_scale
        w_ceil = np.ceil(w)
        y_lengths = np.clip(np.sum(w_ceil, (1, 2)), a_min=1.0, a_max=100000).astype(
            np.int64
        )
        y_mask = np.expand_dims(sequence_mask(y_lengths, None), 1)
        attn_mask = np.expand_dims(x_mask, 2) * np.expand_dims(y_mask, -1)
        attn = generate_path(w_ceil, attn_mask)
        m_p = np.matmul(attn.squeeze(1), m_p.transpose(0, 2, 1)).transpose(
            0, 2, 1
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = np.matmul(attn.squeeze(1), logs_p.transpose(0, 2, 1)).transpose(
            0, 2, 1
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = (
            m_p
            + np.random.randn(m_p.shape[0], m_p.shape[1], m_p.shape[2])
            * np.exp(logs_p)
            * seq_noise_scale
        )

        z = self.flow.run(
            None,
            {
                "z_p": z_p.astype(np.float32),
                "y_mask": y_mask.astype(np.float32),
                "g": g,
            },
        )[0]

        return self.dec.run(None, {"z_in": z.astype(np.float32), "g": g})[0]
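A minimal usage sketch for the session above, assuming the six graphs produced by the V230 export. The model name, dummy inputs, and shapes are illustrative; in practice the phone/tone/language ids and the three [T, 1024] BERT feature matrices come from the repository's text preprocessing.

# Sketch only: "BertVits" and the zero inputs are placeholders.
import numpy as np
from onnx_modules.V230_OnnxInference import OnnxInferenceSession

# Paths follow the f"onnx/{path}/{path}_*.onnx" naming used by export_onnx.
session = OnnxInferenceSession(
    {
        "enc": "onnx/BertVits/BertVits_enc_p.onnx",
        "emb_g": "onnx/BertVits/BertVits_emb.onnx",
        "dp": "onnx/BertVits/BertVits_dp.onnx",
        "sdp": "onnx/BertVits/BertVits_sdp.onnx",
        "flow": "onnx/BertVits/BertVits_flow.onnx",
        "dec": "onnx/BertVits/BertVits_dec.onnx",
    },
    Providers=["CPUExecutionProvider"],
)

T = 32  # phone sequence length of the dummy input
audio = session(
    seq=np.zeros(T, dtype=np.int64),
    tone=np.zeros(T, dtype=np.int64),
    language=np.zeros(T, dtype=np.int64),
    bert_zh=np.zeros((T, 1024), dtype=np.float32),
    bert_jp=np.zeros((T, 1024), dtype=np.float32),
    bert_en=np.zeros((T, 1024), dtype=np.float32),
    sid=np.array([0], dtype=np.int64),
    sdp_ratio=0.0,
    length_scale=1.0,
)
print(audio.shape)  # decoder output waveform, shaped [1, 1, samples]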
onnx_modules/__init__.py CHANGED
@@ -1,14 +1,21 @@
 from utils import get_hparams_from_file, load_checkpoint
 import json


-def export_onnx(export_path, model_path, config_path):
+def export_onnx(export_path, model_path, config_path, novq, dev):
     hps = get_hparams_from_file(config_path)
     version = hps.version[0:3]
-    if version == "2.0":
+    if version == "2.0" or (version == "2.1" and novq):
         from .V200 import SynthesizerTrn, symbols
-    elif version == "2.1":
+    elif version == "2.1" and (not novq):
         from .V210 import SynthesizerTrn, symbols
+    elif version == "2.2":
+        if novq and dev:
+            from .V220_novq_dev import SynthesizerTrn, symbols
+        else:
+            from .V220 import SynthesizerTrn, symbols
+    elif version == "2.3":
+        from .V230 import SynthesizerTrn, symbols
     net_g = SynthesizerTrn(
         len(symbols),
         hps.data.filter_length // 2 + 1,
@@ -41,6 +48,7 @@ def export_onnx(export_path, model_path, config_path):
             "deberta-v2-large-japanese",
             "bert-base-japanese-v3",
         ],
+        "Clap": "clap-htsat-fused",
     }

     with open(f"onnx/{export_path}.json", "w") as MoeVsConfFile:
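With these branches, the version string in the config selects which module family is exported, and the new novq/dev flags disambiguate the 2.0/2.1/2.2 variants; a 2.3 config always routes to .V230. A hedged invocation sketch, with placeholder paths and model name:

from onnx_modules import export_onnx

# Sketch only: paths are placeholders; export_path also names the onnx/<export_path>/ output folder
# and the onnx/<export_path>.json MoeVS config written by this function.
export_onnx(
    export_path="BertVits",
    model_path="path/to/G_latest.pth",
    config_path="path/to/config.json",
    novq=False,
    dev=False,
)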
re_matching.py CHANGED
@@ -44,7 +44,6 @@ def text_matching(text: str) -> list:
     result = []
     for speaker, dialogue in matches:
         result.append(extract_language_and_text_updated(speaker, dialogue))
-    print(result)
     return result
requirements.txt CHANGED
@@ -11,7 +11,7 @@ jieba
 transformers
 pypinyin
 cn2an
-gradio==3.
+gradio==3.50.2
 av
 mecab-python3
 loguru
@@ -21,8 +21,7 @@ fugashi
 num2words
 PyYAML
 requests
-pyopenjtalk
-openjtalk; sys_platform != 'linux'
+pyopenjtalk-prebuilt
 jaconv
 psutil
 GPUtil
|