Sync from GitHub repo
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- api.py +2 -2
- inference-cli.py +2 -2
- model/utils_infer.py +22 -10
api.py CHANGED

@@ -33,10 +33,10 @@ class F5TTS:
         )
 
         # Load models
-        self.…
+        self.load_vocoder_model(local_path)
         self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
 
-    def …
+    def load_vocoder_model(self, local_path):
         self.vocos = load_vocoder(local_path is not None, local_path, self.device)
 
     def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):

(The removed lines are truncated in the synced diff view; only their leading tokens survive.)
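The refactor pulls vocoder loading out of __init__ into its own load_vocoder_model method, mirroring the existing load_ema_model. One practical consequence, sketched below: a subclass can replace just the vocoder step. LocalVocoderF5TTS and the pinned directory path are hypothetical, not part of the commit; F5TTS and load_vocoder are the names from the diff.

# Hypothetical sketch: with vocoder loading isolated, a subclass can override
# only that step. The local directory is an invented example path;
# load_vocoder's signature is taken from model/utils_infer.py below.
from api import F5TTS
from model.utils_infer import load_vocoder


class LocalVocoderF5TTS(F5TTS):
    def load_vocoder_model(self, local_path):
        # Always load vocos from a pinned local checkout instead of the Hub.
        self.vocos = load_vocoder(is_local=True, local_path="ckpts/vocos-mel-24khz", device=self.device)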
inference-cli.py CHANGED

@@ -104,7 +104,7 @@ if model == "F5-TTS":
     exp_name = "F5TTS_Base"
     ckpt_step = 1200000
     ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-    # …
+    # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
 
 elif model == "E2-TTS":
     model_cls = UNetT

@@ -114,7 +114,7 @@ elif model == "E2-TTS":
     exp_name = "E2TTS_Base"
     ckpt_step = 1200000
     ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-    # …
+    # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
 
 print(f"Using {model}...")
 ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)

(The two removed comment lines are truncated in the synced diff view.)
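The two added comments document a local-checkpoint escape hatch next to the default Hub download. A minimal sketch of both paths, assuming repo_name = "F5-TTS" and the ckpts/ layout implied by the commented-out line:

# Sketch of the two checkpoint sources the CLI now documents. repo_name and
# the ckpts/ directory layout are assumptions taken from the surrounding script.
from cached_path import cached_path

repo_name = "F5-TTS"
exp_name = "F5TTS_Base"
ckpt_step = 1200000

# Default: download the checkpoint from the Hugging Face Hub (cached locally).
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))

# Alternative: point at a local file instead (.pt and .safetensors both load).
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"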
model/utils_infer.py CHANGED

@@ -22,13 +22,6 @@ from model.utils import (
 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 print(f"Using {device} device")
 
-asr_pipe = pipeline(
-    "automatic-speech-recognition",
-    model="openai/whisper-large-v3-turbo",
-    torch_dtype=torch.float16,
-    device=device,
-)
-
 vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
 

@@ -82,8 +75,6 @@ def chunk_text(text, max_chars=135):
 
 
 # load vocoder
-
-
 def load_vocoder(is_local=False, local_path="", device=device):
     if is_local:
         print(f"Load vocos from local path {local_path}")

@@ -97,6 +88,22 @@ def load_vocoder(is_local=False, local_path="", device=device):
     return vocos
 
 
+# load asr pipeline
+
+asr_pipe = None
+
+
+def initialize_asr_pipeline(device=device):
+    global asr_pipe
+
+    asr_pipe = pipeline(
+        "automatic-speech-recognition",
+        model="openai/whisper-large",
+        torch_dtype=torch.float16,
+        device=device,
+    )
+
+
 # load model for inference
 
 

@@ -133,7 +140,7 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method="euler
 # preprocess reference audio and text
 
 
-def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
+def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
     show_info("Converting audio...")
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)

@@ -152,6 +159,9 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
     ref_audio = f.name
 
     if not ref_text.strip():
+        global asr_pipe
+        if asr_pipe is None:
+            initialize_asr_pipeline(device=device)
         show_info("No reference text provided, transcribing reference audio...")
         ref_text = asr_pipe(
             ref_audio,

@@ -329,6 +339,8 @@ def infer_batch_process(
 
 
 # remove silence from generated wav
+
+
 def remove_silence_for_generated_wav(filename):
     aseg = AudioSegment.from_file(filename)
     non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)