+
+ #### 此沙盒使用 Hugging Face CPU,請預期超過 200 秒的推理時間,您可以考慮以下方法加速:
+ 1. 複製這個 Space(僅當執行需要排隊時)
+ 2. 複製至本地 GPU 執行(請參考[指南](https://huggingface.co/docs/hub/en/spaces-overview))或使用 [kaggle](https://www.kaggle.com/code/a24998667/breezyvoice-playground)
+ 3. 複製至本地 CPU 執行(請參考[指南](https://huggingface.co/docs/hub/en/spaces-overview))
+
+ 為了加快推理速度,g2pw 注音標註並未啟用。
+
+ 免責聲明:此沙盒於一次性容器中執行,關閉後檔案將遭到刪除。此沙盒不屬於聯發創新基地,聯發創新基地無法獲得任何使用者輸入。"""
+ )
+
+ # All content arranged in a single column
+ with gr.Column():
+ # Configuration Section
+
+ # Step 1: prompt audio input plus its transcript (auto-filled by speech recognition, manually correctable)
+ gr.Markdown("### 步驟 1. 音訊樣本輸入 & 音訊樣本文本輸入")
+ gr.Markdown("選擇prompt音訊檔案或錄製prompt音訊,並手動校對自動產生的音訊樣本文本。")
+ prompt_wav_upload = gr.Audio(
+ sources='upload',
+ type='filepath',
+ label='選擇 prompt 音訊檔案(確保取樣率不低於 16 kHz)'
+ )
+ prompt_wav_record = gr.Audio(
+ sources='microphone',
+ type='filepath',
+ label='錄製prompt音訊檔案'
+ )
+
+ with gr.Blocks():
+ select_which = gr.Radio(["上傳檔案", "麥克風"], label="音訊來源", interactive=True)
+ with gr.Blocks():
+ prompt_text = gr.Textbox(
+ label="音訊樣本文本輸入(此欄位應與音檔內容完全相同)",
+ lines=2,
+ placeholder="音訊樣本文本"
+ )
+
+ # When either prompt-audio input changes, remember which source was provided most recently
+ prompt_wav_upload.change(
+ fn=lambda file: "上傳檔案",
+ inputs=[prompt_wav_upload],
+ outputs=select_which
+ )
+
+ prompt_wav_record.change(
+ fn=lambda recording: "麥克風",
+ inputs=[prompt_wav_record],
+ outputs=select_which
+ )
+
+ # Run speech recognition on the currently selected prompt audio to pre-fill the transcript field
+ select_which.change(
+ fn=generate_text,
+ inputs=[prompt_wav_upload, prompt_wav_record, select_which],
+ outputs=prompt_text
+ )
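+ # generate_text is defined earlier in this file (outside this hunk). A minimal sketch of what it
+ # is assumed to do, with illustrative names only:
+ #
+ #     def generate_text(upload_path, record_path, source):
+ #         wav_path = upload_path if source == "上傳檔案" else record_path
+ #         if not wav_path:
+ #             return ""
+ #         return asr_pipeline(wav_path)["text"]   # Whisper pre-fills the prompt transcript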
+ # Input Section: Synthesis Text
+
+ gr.Markdown("### 步驟 2.合成文本輸入")
+ tts_text = gr.Textbox(
+ label="輸入想要合成的文本",
+ lines=2,
+ placeholder="請輸入想要合成的文本...",
+ value="你好,歡迎光臨"
+ )
+
+
+ # Output Section
+ gr.Markdown("### 步驟 3. 合成音訊")
+ # Advanced settings, followed by the generation button and the synthesized-audio output
+
+ with gr.Accordion("進階設定", open=False):
+ seed = gr.Number(value=0, label="隨機推理種子")
+ seed_button = gr.Button(value="\U0001F3B2生成隨機推理種子\U0001F3B2")
+ speed_factor = 1  # 語速目前固定為 1,語速調整尚未開放
+
+ generate_button = gr.Button("生成音訊")
+ audio_output = gr.Audio(label="合成音訊")
+
+ # Set up callbacks for seed generation and audio synthesis
+ seed_button.click(fn=generate_seed, inputs=[], outputs=seed)
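+ # generate_seed (defined earlier in this file) is assumed to return a fresh random integer,
+ # e.g. random.randint(1, 100000000), which generate_audio then uses to seed the RNGs.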
+ generate_button.click(
+ fn=generate_audio,
+ inputs=[tts_text, prompt_text, prompt_wav_upload, prompt_wav_record, seed, select_which],
+ outputs=audio_output
+ )
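+ # generate_audio (defined earlier in this file) is assumed to follow the usual CosyVoice zero-shot
+ # recipe: load whichever prompt wav the radio selects, validate/resample it to 16 kHz, seed the RNGs,
+ # then do roughly
+ #     output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
+ #     return (target_sr, output['tts_speech'].numpy().flatten())
+ # so Gradio receives a (sample_rate, waveform) tuple for the audio component.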
+
+ demo.queue(max_size=4, default_concurrency_limit=2)
+ demo.launch()
+
+if __name__ == '__main__':
+ cosyvoice = CosyVoice('Splend1dchan/BreezyVoice')
+ asr_pipeline = pipeline(
+ "automatic-speech-recognition",
+ model="openai/whisper-tiny",
+ tokenizer="openai/whisper-tiny",
+ device=0  # first CUDA GPU; set to -1 to run on CPU (required on CPU-only Spaces)
+ )
+ sft_spk = cosyvoice.list_avaliable_spks()
+ prompt_sr, target_sr = 16000, 22050
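+ # 16 kHz is the sample rate expected of the prompt audio (speech tokenizer and speaker embedding);
+ # 22.05 kHz is the rate at which the model synthesizes output audio.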
+ default_data = np.zeros(target_sr)
+ main()
diff --git a/cosyvoice/__init__.py b/cosyvoice/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cosyvoice/__pycache__/__init__.cpython-310.pyc b/cosyvoice/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..977cc13a136099e28940d51e6eb28e99c980b110
Binary files /dev/null and b/cosyvoice/__pycache__/__init__.cpython-310.pyc differ
diff --git a/cosyvoice/__pycache__/__init__.cpython-38.pyc b/cosyvoice/__pycache__/__init__.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..ab53a14da62922ed1ca6892ea79d4e4c0f5c90f9
Binary files /dev/null and b/cosyvoice/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cosyvoice/bin/inference.py b/cosyvoice/bin/inference.py
new file mode 100755
index 0000000000000000000000000000000000000000..6b777fa1cba925f9786db60b7efa15dcd189adeb
--- /dev/null
+++ b/cosyvoice/bin/inference.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+
+import torch
+from torch.utils.data import DataLoader
+import torchaudio
+from hyperpyyaml import load_hyperpyyaml
+from tqdm import tqdm
+from cosyvoice.cli.model import CosyVoiceModel
+
+from cosyvoice.dataset.dataset import Dataset
+
+def get_args():
+ parser = argparse.ArgumentParser(description='inference with your model')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
+ parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
+ parser.add_argument('--tts_text', required=True, help='tts input file')
+ parser.add_argument('--llm_model', required=True, help='llm model file')
+ parser.add_argument('--flow_model', required=True, help='flow model file')
+ parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
+ parser.add_argument('--gpu',
+ type=int,
+ default=-1,
+ help='gpu id for this rank, -1 for cpu')
+ parser.add_argument('--mode',
+ default='sft',
+ choices=['sft', 'zero_shot'],
+ help='inference mode')
+ parser.add_argument('--result_dir', required=True, help='asr result file')
+ args = parser.parse_args()
+ print(args)
+ return args
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+
+ # Init cosyvoice models from configs
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
+ device = torch.device('cuda' if use_cuda else 'cpu')
+ with open(args.config, 'r') as f:
+ configs = load_hyperpyyaml(f)
+
+ model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+ model.load(args.llm_model, args.flow_model, args.hifigan_model)
+
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
+
+ del configs
+ os.makedirs(args.result_dir, exist_ok=True)
+ fn = os.path.join(args.result_dir, 'wav.scp')
+ f = open(fn, 'w')
+ with torch.no_grad():
+ for batch_idx, batch in tqdm(enumerate(test_data_loader)):
+ utts = batch["utts"]
+ assert len(utts) == 1, "inference mode only supports batch size 1"
+ text = batch["text"]
+ text_token = batch["text_token"].to(device)
+ text_token_len = batch["text_token_len"].to(device)
+ tts_text = batch["tts_text"]
+ tts_index = batch["tts_index"]
+ tts_text_token = batch["tts_text_token"].to(device)
+ tts_text_token_len = batch["tts_text_token_len"].to(device)
+ speech_token = batch["speech_token"].to(device)
+ speech_token_len = batch["speech_token_len"].to(device)
+ speech_feat = batch["speech_feat"].to(device)
+ speech_feat_len = batch["speech_feat_len"].to(device)
+ utt_embedding = batch["utt_embedding"].to(device)
+ spk_embedding = batch["spk_embedding"].to(device)
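+ # 'sft' conditions only on a pre-registered speaker embedding, while 'zero_shot' additionally
+ # conditions on the prompt transcript, prompt speech tokens and prompt mel features.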
+ if args.mode == 'sft':
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+ 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
+ else:
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+ 'prompt_text': text_token, 'prompt_text_len': text_token_len,
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+ 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
+ model_output = model.inference(**model_input)
+ tts_key = '{}_{}'.format(utts[0], tts_index[0])
+ tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
+ torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
+ f.write('{} {}\n'.format(tts_key, tts_fn))
+ f.flush()
+ f.close()
+ logging.info('Result wav.scp saved in {}'.format(fn))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cosyvoice/bin/train.py b/cosyvoice/bin/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..a9d0e0581d81a8964683dea4af2fd0f407eab5e8
--- /dev/null
+++ b/cosyvoice/bin/train.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import argparse
+import datetime
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+from copy import deepcopy
+import torch
+import torch.distributed as dist
+import deepspeed
+
+from hyperpyyaml import load_hyperpyyaml
+
+from torch.distributed.elastic.multiprocessing.errors import record
+
+from cosyvoice.utils.executor import Executor
+from cosyvoice.utils.train_utils import (
+ init_distributed,
+ init_dataset_and_dataloader,
+ init_optimizer_and_scheduler,
+ init_summarywriter, save_model,
+ wrap_cuda_model, check_modify_and_save_config)
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='training your network')
+ parser.add_argument('--train_engine',
+ default='torch_ddp',
+ choices=['torch_ddp', 'deepspeed'],
+ help='Engine for parallel training')
+ parser.add_argument('--model', required=True, help='model which will be trained')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument('--train_data', required=True, help='train data file')
+ parser.add_argument('--cv_data', required=True, help='cv data file')
+ parser.add_argument('--checkpoint', help='checkpoint model')
+ parser.add_argument('--model_dir', required=True, help='save model dir')
+ parser.add_argument('--tensorboard_dir',
+ default='tensorboard',
+ help='tensorboard log dir')
+ parser.add_argument('--ddp.dist_backend',
+ dest='dist_backend',
+ default='nccl',
+ choices=['nccl', 'gloo'],
+ help='distributed backend')
+ parser.add_argument('--num_workers',
+ default=0,
+ type=int,
+ help='num of subprocess workers for reading')
+ parser.add_argument('--prefetch',
+ default=100,
+ type=int,
+ help='prefetch number')
+ parser.add_argument('--pin_memory',
+ action='store_true',
+ default=False,
+ help='Use pinned memory buffers for reading')
+ parser.add_argument('--deepspeed.save_states',
+ dest='save_states',
+ default='model_only',
+ choices=['model_only', 'model+optimizer'],
+ help='save model/optimizer states')
+ parser.add_argument('--timeout',
+ default=30,
+ type=int,
+ help='timeout (in seconds) of cosyvoice_join.')
+ parser = deepspeed.add_config_arguments(parser)
+ args = parser.parse_args()
+ return args
+
+
+@record
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+ override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
+ with open(args.config, 'r') as f:
+ configs = load_hyperpyyaml(f, overrides=override_dict)
+ configs['train_conf'].update(vars(args))
+
+ # Init env for ddp
+ init_distributed(args)
+
+ # Get dataset & dataloader
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
+ init_dataset_and_dataloader(args, configs)
+
+ # Do some sanity checks and save config to args.model_dir
+ configs = check_modify_and_save_config(args, configs)
+
+ # Tensorboard summary
+ writer = init_summarywriter(args)
+
+ # load checkpoint
+ model = configs[args.model]
+ if args.checkpoint is not None:
+ model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
+
+ # Dispatch model from cpu to gpu
+ model = wrap_cuda_model(args, model)
+
+ # Get optimizer & scheduler
+ model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
+
+ # Save init checkpoints
+ info_dict = deepcopy(configs['train_conf'])
+ save_model(model, 'init', info_dict)
+
+ # Get executor
+ executor = Executor()
+
+ # Start training loop
+ for epoch in range(info_dict['max_epoch']):
+ executor.epoch = epoch
+ train_dataset.set_epoch(epoch)
+ dist.barrier()
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
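+ # The per-epoch gloo group with a short timeout is assumed to back the join logic in
+ # cosyvoice.utils.train_utils, letting ranks that exhaust their data early leave the epoch
+ # without hanging the main process group.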
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
+ dist.destroy_process_group(group_join)
+
+if __name__ == '__main__':
+ main()
diff --git a/cosyvoice/cli/__init__.py b/cosyvoice/cli/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cosyvoice/cli/__pycache__/__init__.cpython-310.pyc b/cosyvoice/cli/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88fe9a437aa0e1ce0c953d0dda98dc881805a94c
Binary files /dev/null and b/cosyvoice/cli/__pycache__/__init__.cpython-310.pyc differ
diff --git a/cosyvoice/cli/__pycache__/__init__.cpython-38.pyc b/cosyvoice/cli/__pycache__/__init__.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..0b17d8ae1151befb71a194cc899fb6fa5b2987f2
Binary files /dev/null and b/cosyvoice/cli/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cosyvoice/cli/__pycache__/cosyvoice.cpython-310.pyc b/cosyvoice/cli/__pycache__/cosyvoice.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98fd1254857696d925c565710dd6fa2203c4c772
Binary files /dev/null and b/cosyvoice/cli/__pycache__/cosyvoice.cpython-310.pyc differ
diff --git a/cosyvoice/cli/__pycache__/cosyvoice.cpython-38.pyc b/cosyvoice/cli/__pycache__/cosyvoice.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..871d57d7c6ab704a74629a024353efc322a98647
Binary files /dev/null and b/cosyvoice/cli/__pycache__/cosyvoice.cpython-38.pyc differ
diff --git a/cosyvoice/cli/__pycache__/frontend.cpython-310.pyc b/cosyvoice/cli/__pycache__/frontend.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb1e9fd4489993f3433ceb33cd822aa05dcb11b0
Binary files /dev/null and b/cosyvoice/cli/__pycache__/frontend.cpython-310.pyc differ
diff --git a/cosyvoice/cli/__pycache__/frontend.cpython-38.pyc b/cosyvoice/cli/__pycache__/frontend.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..57e3e47696770f0e86036ab6c6a644bc8d5a57fe
Binary files /dev/null and b/cosyvoice/cli/__pycache__/frontend.cpython-38.pyc differ
diff --git a/cosyvoice/cli/__pycache__/model.cpython-310.pyc b/cosyvoice/cli/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..625589637de20f2d1fb5e12c50d7880eb6b26f87
Binary files /dev/null and b/cosyvoice/cli/__pycache__/model.cpython-310.pyc differ
diff --git a/cosyvoice/cli/__pycache__/model.cpython-38.pyc b/cosyvoice/cli/__pycache__/model.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..913ce90bbb03b1c33f2d0a225c75abbd4f2d5570
Binary files /dev/null and b/cosyvoice/cli/__pycache__/model.cpython-38.pyc differ
diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
new file mode 100755
index 0000000000000000000000000000000000000000..25743a6a8b747061e4563f2eb62da3276fd19cce
--- /dev/null
+++ b/cosyvoice/cli/cosyvoice.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import torch
+from hyperpyyaml import load_hyperpyyaml
+from huggingface_hub import snapshot_download
+from cosyvoice.cli.frontend import CosyVoiceFrontEnd
+from cosyvoice.cli.model import CosyVoiceModel
+
+class CosyVoice:
+
+ def __init__(self, model_dir):
+ instruct = '-Instruct' in model_dir
+ self.model_dir = model_dir
+ if not os.path.exists(model_dir):
+ model_dir = snapshot_download(model_dir)
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+ configs = load_hyperpyyaml(f)
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+ configs['feat_extractor'],
+ '{}/campplus.onnx'.format(model_dir),
+ '{}/speech_tokenizer_v1.onnx'.format(model_dir),
+ '{}/spk2info.pt'.format(model_dir),
+ instruct,
+ configs['allowed_special'])
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+ self.model.load('{}/llm.pt'.format(model_dir),
+ '{}/flow.pt'.format(model_dir),
+ '{}/hift.pt'.format(model_dir))
+ del configs
+
+ def list_avaliable_spks(self):
+ spks = list(self.frontend.spk2info.keys())
+ return spks
+
+ def inference_sft(self, tts_text, spk_id):
+ tts_speeches = []
+ for i in self.frontend.text_normalize(tts_text, split=True):
+ model_input = self.frontend.frontend_sft(i, spk_id)
+ model_output = self.model.inference(**model_input)
+ tts_speeches.append(model_output['tts_speech'])
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+
+ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False)
+ tts_speeches = []
+ for i in self.frontend.text_normalize(tts_text, split=True):
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
+ model_output = self.model.inference(**model_input)
+ tts_speeches.append(model_output['tts_speech'])
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+
+ def inference_cross_lingual(self, tts_text, prompt_speech_16k):
+ if self.frontend.instruct is True:
+ raise ValueError('{} does not support cross_lingual inference'.format(self.model_dir))
+ tts_speeches = []
+ for i in self.frontend.text_normalize(tts_text, split=True):
+ model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
+ model_output = self.model.inference(**model_input)
+ tts_speeches.append(model_output['tts_speech'])
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+
+ def inference_instruct(self, tts_text, spk_id, instruct_text):
+ if self.frontend.instruct is False:
+ raise ValueError('{} does not support instruct inference'.format(self.model_dir))
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False)
+ tts_speeches = []
+ for i in self.frontend.text_normalize(tts_text, split=True):
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
+ model_output = self.model.inference(**model_input)
+ tts_speeches.append(model_output['tts_speech'])
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
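+
+# Minimal usage sketch (illustrative; assumes a 16 kHz mono prompt recording on disk):
+#
+#     import torchaudio
+#     from cosyvoice.cli.cosyvoice import CosyVoice
+#
+#     cosyvoice = CosyVoice('Splend1dchan/BreezyVoice')
+#     prompt_speech_16k, sr = torchaudio.load('prompt.wav')    # resample first if sr != 16000
+#     out = cosyvoice.inference_zero_shot('你好,歡迎光臨', '提示音檔的逐字稿', prompt_speech_16k)
+#     torchaudio.save('output.wav', out['tts_speech'], 22050)  # synthesis rate is 22.05 kHz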
diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py
new file mode 100755
index 0000000000000000000000000000000000000000..4e4f8c2a08c2ceda88854f1d196bcd28bbe6681c
--- /dev/null
+++ b/cosyvoice/cli/frontend.py
@@ -0,0 +1,183 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+import onnxruntime
+import torch
+import numpy as np
+import whisper
+from typing import Callable
+import torchaudio.compliance.kaldi as kaldi
+import torchaudio
+import os
+import re
+import inflect
+import subprocess
+try:
+ import ttsfrd
+ use_ttsfrd = True
+except ImportError:
+ print("failed to import ttsfrd, use WeTextProcessing instead")
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
+ from tn.english.normalizer import Normalizer as EnNormalizer
+ use_ttsfrd = False
+from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
+
+
+class CosyVoiceFrontEnd:
+
+ def __init__(self,
+ get_tokenizer: Callable,
+ feat_extractor: Callable,
+ campplus_model: str,
+ speech_tokenizer_model: str,
+ spk2info: str = '',
+ instruct: bool = False,
+ allowed_special: str = 'all'):
+ self.tokenizer = get_tokenizer()
+ self.feat_extractor = feat_extractor
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ option = onnxruntime.SessionOptions()
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ option.intra_op_num_threads = 1
+ self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"])
+ if os.path.exists(spk2info):
+ self.spk2info = torch.load(spk2info, map_location=self.device)
+ self.instruct = instruct
+ self.allowed_special = allowed_special
+ self.inflect_parser = inflect.engine()
+ self.use_ttsfrd = use_ttsfrd
+ if self.use_ttsfrd:
+ self.frd = ttsfrd.TtsFrontendEngine()
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ #print("LOCATION",ttsfrd.__file__)
+ #print('TTSFRD FILES',os.listdir(ttsfrd.__file__))
+ if not os.path.exists('resource.zip'):
+ # Download the file if it does not exist
+ subprocess.run("wget https://huggingface.co/FunAudioLLM/CosyVoice-ttsfrd/resolve/main/resource.zip".split())
+
+ # Unzip the file if it exists
+ if not os.path.exists('resource'):
+ subprocess.run("unzip resource.zip".split())
+ else:
+ pass
+ #print(os.listdir())
+ #print(subprocess.run("pwd"))
+ print("root",ROOT_DIR)
+ assert self.frd.initialize('{}/../../resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
+ self.frd.set_lang_type('pinyin')
+ self.frd.enable_pinyin_mix(True)
+ self.frd.set_breakmodel_index(1)
+ else:
+ self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
+ self.en_tn_model = EnNormalizer()
+
+ def _extract_text_token(self, text):
+ text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+ return text_token, text_token_len
+
+ def _extract_speech_token(self, speech):
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
+ speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
+ self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
+ return speech_token, speech_token_len
+
+ def _extract_spk_embedding(self, speech):
+ feat = kaldi.fbank(speech,
+ num_mel_bins=80,
+ dither=0,
+ sample_frequency=16000)
+ feat = feat - feat.mean(dim=0, keepdim=True)
+ embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+ embedding = torch.tensor([embedding]).to(self.device)
+ return embedding
+
+ def _extract_speech_feat(self, speech):
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
+ speech_feat = speech_feat.unsqueeze(dim=0)
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
+ return speech_feat, speech_feat_len
+
+ def text_normalize(self, text, split=True):
+ text = text.strip()
+ if contains_chinese(text):
+ if self.use_ttsfrd:
+ text = self.frd.get_frd_extra_info(text, 'input')
+ else:
+ text = self.zh_tn_model.normalize(text)
+ text = text.replace("\n", "")
+ text = replace_blank(text)
+ text = replace_corner_mark(text)
+ text = text.replace(".", "、")
+ text = text.replace(" - ", ",")
+ text = remove_bracket(text)
+ text = re.sub(r'[,,]+$', '。', text)
+ texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+ token_min_n=60, merge_len=20,
+ comma_split=False)]
+ else:
+ if self.use_ttsfrd:
+ text = self.frd.get_frd_extra_info(text, 'input')
+ else:
+ text = self.en_tn_model.normalize(text)
+ text = spell_out_number(text, self.inflect_parser)
+ texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+ token_min_n=60, merge_len=20,
+ comma_split=False)]
+ if split is False:
+ return text
+ return texts
+
+ def frontend_sft(self, tts_text, spk_id):
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+ embedding = self.spk2info[spk_id]['embedding']
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
+ return model_input
+
+ def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
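+ # The mel feature extractor operates at the 22.05 kHz synthesis rate, while the speech tokenizer
+ # and the speaker-embedding model (campplus.onnx) consume the original 16 kHz prompt.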
+ prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
+ return model_input
+
+ def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
+ model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
+ # in cross lingual mode, we remove prompt in llm
+ del model_input['prompt_text']
+ del model_input['prompt_text_len']
+ del model_input['llm_prompt_speech_token']
+ del model_input['llm_prompt_speech_token_len']
+ return model_input
+
+ def frontend_instruct(self, tts_text, spk_id, instruct_text):
+ model_input = self.frontend_sft(tts_text, spk_id)
+ # in instruct mode, we remove spk_embedding in llm due to information leakage
+ del model_input['llm_embedding']
+ instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '
+
+