Sync from GitHub repo
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there.
- pyproject.toml +1 -0
- src/f5_tts/configs/E2TTS_Base_train.yaml +43 -0
- src/f5_tts/configs/E2TTS_Small_train.yaml +43 -0
- src/f5_tts/configs/F5TTS_Base_train.yaml +45 -0
- src/f5_tts/configs/F5TTS_Small_train.yaml +45 -0
- src/f5_tts/eval/README.md +2 -2
- src/f5_tts/eval/eval_infer_batch.py +2 -2
- src/f5_tts/eval/eval_librispeech_test_clean.py +63 -52
- src/f5_tts/eval/eval_seedtts_testset.py +63 -54
- src/f5_tts/train/README.md +4 -1
- src/f5_tts/train/datasets/prepare_ljspeech.py +64 -0
- src/f5_tts/train/train.py +39 -68
pyproject.toml
CHANGED
@@ -39,6 +39,7 @@ dependencies = [
     "vocos",
     "wandb",
     "x_transformers>=1.31.14",
+    "hydra-core>=1.3.0",
 ]
 
 [project.optional-dependencies]
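The new `hydra-core` pin is what the reworked `train.py` (further down in this commit) builds on. As a minimal sketch of the pattern, assuming a script living at the repo root (the decorator arguments mirror the ones used in `train.py`; `config_name=None` defers the choice of YAML to the `--config-name` launch flag):

```python
import hydra
from omegaconf import DictConfig  # omegaconf ships with hydra-core


@hydra.main(version_base="1.3", config_path="src/f5_tts/configs", config_name=None)
def main(cfg: DictConfig):
    # cfg is the parsed YAML; dotted access mirrors the file structure
    print(cfg.model.name, cfg.optim.learning_rate)


if __name__ == "__main__":
    main()
```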
src/f5_tts/configs/E2TTS_Base_train.yaml
ADDED
@@ -0,0 +1,43 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # "frame" or "sample"
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers
+
+optim:
+  epochs: 15 # max epochs
+  learning_rate: 7.5e-5 # learning rate
+  num_warmup_updates: 20000 # warmup steps
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb optimizer or not
+
+model:
+  name: E2TTS_Base # model name
+  tokenizer: pinyin # tokenizer type
+  tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024 # model dimension
+    depth: 24 # number of transformer layers
+    heads: 16 # number of transformer heads
+    ff_mult: 4 # ff layer expansion
+  mel_spec:
+    target_sample_rate: 24000 # target sample rate
+    n_mel_channels: 100 # mel channel
+    hop_length: 256 # hop length
+    win_length: 1024 # window length
+    n_fft: 1024 # fft length
+    mel_spec_type: vocos # 'vocos' or 'bigvgan'
+    is_local_vocoder: False # use local vocoder or not
+    local_vocoder_path: None # path to local vocoder
+
+ckpts:
+  logger: wandb # wandb | tensorboard | None
+  save_per_updates: 50000 # save checkpoint per steps
+  last_per_steps: 5000 # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
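Since `batch_size_type` is `frame`, `batch_size_per_gpu` counts mel frames rather than utterances. A quick sanity check of what that means in audio time, using the `mel_spec` values above (illustrative arithmetic, not repo code):

```python
# one mel frame spans hop_length / target_sample_rate seconds
frames_per_gpu = 38400
hop_length, target_sample_rate = 256, 24000
audio_seconds = frames_per_gpu * hop_length / target_sample_rate
print(audio_seconds)  # 409.6 s of audio per GPU per optimizer batch
```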
src/f5_tts/configs/E2TTS_Small_train.yaml
ADDED
@@ -0,0 +1,43 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # "frame" or "sample"
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup steps
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0
+  bnb_optimizer: False
+
+model:
+  name: E2TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 768
+    depth: 20
+    heads: 12
+    ff_mult: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # 'vocos' or 'bigvgan'
+    is_local_vocoder: False
+    local_vocoder_path: None
+
+ckpts:
+  logger: wandb # wandb | tensorboard | None
+  save_per_updates: 50000 # save checkpoint per steps
+  last_per_steps: 5000 # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
src/f5_tts/configs/F5TTS_Base_train.yaml
ADDED
@@ -0,0 +1,45 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # "frame" or "sample"
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers
+
+optim:
+  epochs: 15 # max epochs
+  learning_rate: 7.5e-5 # learning rate
+  num_warmup_updates: 20000 # warmup steps
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb optimizer or not
+
+model:
+  name: F5TTS_Base # model name
+  tokenizer: pinyin # tokenizer type
+  tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024 # model dim
+    depth: 22 # model depth
+    heads: 16 # model heads
+    ff_mult: 2 # feedforward expansion
+    text_dim: 512 # text encoder dim
+    conv_layers: 4 # convolution layers
+  mel_spec:
+    target_sample_rate: 24000 # target sample rate
+    n_mel_channels: 100 # mel channel
+    hop_length: 256 # hop length
+    win_length: 1024 # window length
+    n_fft: 1024 # fft length
+    mel_spec_type: vocos # 'vocos' or 'bigvgan'
+    is_local_vocoder: False # use local vocoder or not
+    local_vocoder_path: None # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | None
+  save_per_updates: 50000 # save checkpoint per steps
+  last_per_steps: 5000 # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
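At launch, hydra resolves the `${...}` interpolations above; with these defaults the run and checkpoint directory expands to something like the following (timestamp illustrative):

```
ckpts/F5TTS_Base_vocos_pinyin_Emilia_ZH_EN/2024-11-01/09-15-00
```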
src/f5_tts/configs/F5TTS_Small_train.yaml
ADDED
@@ -0,0 +1,45 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # "frame" or "sample"
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup steps
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0
+  bnb_optimizer: False
+
+model:
+  name: F5TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 768
+    depth: 18
+    heads: 12
+    ff_mult: 2
+    text_dim: 512
+    conv_layers: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # 'vocos' or 'bigvgan'
+    is_local_vocoder: False
+    local_vocoder_path: None
+
+ckpts:
+  logger: wandb # wandb | tensorboard | None
+  save_per_updates: 50000 # save checkpoint per steps
+  last_per_steps: 5000 # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
src/f5_tts/eval/README.md
CHANGED
@@ -42,8 +42,8 @@ Then update in the following scripts with the paths you put evaluation model ckp
 Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
 ```bash
 # Evaluation for Seed-TTS test set
-python src/f5_tts/eval/eval_seedtts_testset.py
+python src/f5_tts/eval/eval_seedtts_testset.py --gen_wav_dir <GEN_WAVE_DIR>
 
 # Evaluation for LibriSpeech-PC test-clean (cross-sentence)
-python src/f5_tts/eval/eval_librispeech_test_clean.py
+python src/f5_tts/eval/eval_librispeech_test_clean.py --gen_wav_dir <GEN_WAVE_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
 ```
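With the argparse interfaces introduced below, the remaining flags can be spelled out explicitly; for example (directory placeholders as above, defaults shown for clarity):

```bash
# WER on the Seed-TTS en subset, 8 GPUs (defaults: -e wer, -l en, -n 8)
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l en -g <GEN_WAVE_DIR> -n 8

# SIM on LibriSpeech-PC test-clean
python src/f5_tts/eval/eval_librispeech_test_clean.py -e sim -g <GEN_WAVE_DIR> -p <TEST_CLEAN_PATH>
```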
src/f5_tts/eval/eval_infer_batch.py
CHANGED
@@ -34,8 +34,6 @@ win_length = 1024
 n_fft = 1024
 target_rms = 0.1
 
-
-tokenizer = "pinyin"
 rel_path = str(files("f5_tts").joinpath("../../"))
 
 
@@ -49,6 +47,7 @@ def main():
     parser.add_argument("-n", "--expname", required=True)
     parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
     parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
+    parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"])
 
     parser.add_argument("-nfe", "--nfestep", default=32, type=int)
     parser.add_argument("-o", "--odemethod", default="euler")
@@ -64,6 +63,7 @@ def main():
     ckpt_step = args.ckptstep
     ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
     mel_spec_type = args.mel_spec_type
+    tokenizer = args.tokenizer
 
     nfe_step = args.nfestep
     ode_method = args.odemethod
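The new `-to/--tokenizer` flag replaces the previously hard-coded `tokenizer = "pinyin"`. A hypothetical invocation (`<EXP_NAME>` is a placeholder for a checkpoint experiment name):

```bash
python src/f5_tts/eval/eval_infer_batch.py -n <EXP_NAME> -to char -m vocos
```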
src/f5_tts/eval/eval_librispeech_test_clean.py
CHANGED
@@ -2,6 +2,7 @@
 
 import sys
 import os
+import argparse
 
 sys.path.append(os.getcwd())
 
@@ -19,55 +20,65 @@ from f5_tts.eval.utils_eval import (
 rel_path = str(files("f5_tts").joinpath("../../"))
 
 
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
+    parser.add_argument("-l", "--lang", type=str, default="en")
+    parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
+    parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
+    parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
+    parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    eval_task = args.eval_task
+    lang = args.lang
+    librispeech_test_clean_path = args.librispeech_test_clean_path  # test-clean path
+    gen_wav_dir = args.gen_wav_dir
+    metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
+
+    gpus = list(range(args.gpu_nums))
+    test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
+
+    ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
+    ## leading to a low similarity for the ground truth in some cases.
+    # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True)  # eval ground truth
+
+    local = args.local
+    if local:  # use local custom checkpoint dir
+        asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
+    else:
+        asr_ckpt_dir = ""  # auto download to cache dir
+    wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
+
+    # --------------------------- WER ---------------------------
+    if eval_task == "wer":
+        wers = []
+        with mp.Pool(processes=len(gpus)) as pool:
+            args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
+            results = pool.map(run_asr_wer, args)
+            for wers_ in results:
+                wers.extend(wers_)
+
+        wer = round(np.mean(wers) * 100, 3)
+        print(f"\nTotal {len(wers)} samples")
+        print(f"WER : {wer}%")
+
+    # --------------------------- SIM ---------------------------
+    if eval_task == "sim":
+        sim_list = []
+        with mp.Pool(processes=len(gpus)) as pool:
+            args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
+            results = pool.map(run_sim, args)
+            for sim_ in results:
+                sim_list.extend(sim_)
+
+        sim = round(sum(sim_list) / len(sim_list), 3)
+        print(f"\nTotal {len(sim_list)} samples")
+        print(f"SIM : {sim}")
+
+
+if __name__ == "__main__":
+    main()
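`--local` switches the faster-whisper ASR checkpoint from auto-download to the pre-fetched directory hard-coded above; e.g. a hypothetical 4-GPU run with local checkpoints:

```bash
python src/f5_tts/eval/eval_librispeech_test_clean.py -e wer -n 4 --local \
    -g <GEN_WAVE_DIR> -p <TEST_CLEAN_PATH>
```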
src/f5_tts/eval/eval_seedtts_testset.py
CHANGED
@@ -2,6 +2,7 @@
 
 import sys
 import os
+import argparse
 
 sys.path.append(os.getcwd())
 
@@ -19,57 +20,65 @@ from f5_tts.eval.utils_eval import (
 rel_path = str(files("f5_tts").joinpath("../../"))
 
 
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
+    parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
+    parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
+    parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
+    parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    eval_task = args.eval_task
+    lang = args.lang
+    gen_wav_dir = args.gen_wav_dir
+    metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset
+
+    # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
+    # zh 1.254 seems a result of 4 workers wer_seed_tts
+    gpus = list(range(args.gpu_nums))
+    test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
+
+    local = args.local
+    if local:  # use local custom checkpoint dir
+        if lang == "zh":
+            asr_ckpt_dir = "../checkpoints/funasr"  # paraformer-zh dir under funasr
+        elif lang == "en":
+            asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
+    else:
+        asr_ckpt_dir = ""  # auto download to cache dir
+    wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
+
+    # --------------------------- WER ---------------------------
+
+    if eval_task == "wer":
+        wers = []
+        with mp.Pool(processes=len(gpus)) as pool:
+            args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
+            results = pool.map(run_asr_wer, args)
+            for wers_ in results:
+                wers.extend(wers_)
+
+        wer = round(np.mean(wers) * 100, 3)
+        print(f"\nTotal {len(wers)} samples")
+        print(f"WER : {wer}%")
+
+    # --------------------------- SIM ---------------------------
+    if eval_task == "sim":
+        sim_list = []
+        with mp.Pool(processes=len(gpus)) as pool:
+            args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
+            results = pool.map(run_sim, args)
+            for sim_ in results:
+                sim_list.extend(sim_)
+
+        sim = round(sum(sim_list) / len(sim_list), 3)
+        print(f"\nTotal {len(sim_list)} samples")
+        print(f"SIM : {sim}")
+
+
+if __name__ == "__main__":
+    main()
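For the zh subset, `--local` points at the paraformer-zh checkpoint instead of faster-whisper; note the in-code caveat that paraformer-zh WER shifts slightly with GPU count because the batch size changes. A hypothetical run:

```bash
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh -g <GEN_WAVE_DIR> -n 4 --local
```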
src/f5_tts/train/README.md
CHANGED
@@ -16,6 +16,9 @@ python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
 
 # Prepare the LibriTTS dataset
 python src/f5_tts/train/datasets/prepare_libritts.py
+
+# Prepare the LJSpeech dataset
+python src/f5_tts/train/datasets/prepare_ljspeech.py
 ```
 
 ### 2. Create custom dataset with metadata.csv
 
@@ -35,7 +38,7 @@ Once your datasets are prepared, you can start the training process.
 # setup accelerate config, e.g. use multi-gpu ddp, fp16
 # will be to: ~/.cache/huggingface/accelerate/default_config.yaml
 accelerate config
-accelerate launch src/f5_tts/train/train.py
+accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml  # F5TTS_Base_train.yaml | E2TTS_Base_train.yaml
 ```
 
 ### 2. Finetuning practice
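Because `train.py` now goes through hydra (see its diff below), any value in the selected YAML can also be overridden from the command line with dotted keys. An illustrative, non-prescriptive example:

```bash
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml \
    optim.epochs=11 datasets.batch_size_per_gpu=19200
```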
src/f5_tts/train/datasets/prepare_ljspeech.py
ADDED
@@ -0,0 +1,64 @@
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+import json
+from importlib.resources import files
+from pathlib import Path
+from tqdm import tqdm
+import soundfile as sf
+from datasets.arrow_writer import ArrowWriter
+
+
+def main():
+    result = []
+    duration_list = []
+    text_vocab_set = set()
+
+    with open(meta_info, "r") as f:
+        lines = f.readlines()
+        for line in tqdm(lines):
+            uttr, text, norm_text = line.split("|")
+            wav_path = Path(dataset_dir) / "wavs" / f"{uttr}.wav"
+            duration = sf.info(wav_path).duration
+            if duration < 0.4 or duration > 30:
+                continue
+            result.append({"audio_path": str(wav_path), "text": norm_text, "duration": duration})
+            duration_list.append(duration)
+            text_vocab_set.update(list(norm_text))
+
+    # save preprocessed dataset to disk
+    if not os.path.exists(f"{save_dir}"):
+        os.makedirs(f"{save_dir}")
+    print(f"\nSaving to {save_dir} ...")
+
+    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
+        for line in tqdm(result, desc="Writing to raw.arrow ..."):
+            writer.write(line)
+
+    # dup a json separately saving duration in case for DynamicBatchSampler ease
+    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
+        json.dump({"duration": duration_list}, f, ensure_ascii=False)
+
+    # vocab map, i.e. tokenizer
+    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
+    with open(f"{save_dir}/vocab.txt", "w") as f:
+        for vocab in sorted(text_vocab_set):
+            f.write(vocab + "\n")
+
+    print(f"\nFor {dataset_name}, sample count: {len(result)}")
+    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
+    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
+
+
+if __name__ == "__main__":
+    tokenizer = "char"  # "pinyin" | "char"
+
+    dataset_dir = "<SOME_PATH>/LJSpeech-1.1"
+    dataset_name = f"LJSpeech_{tokenizer}"
+    meta_info = os.path.join(dataset_dir, "metadata.csv")
+    save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
+    print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
+
+    main()
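The script assumes the stock LJSpeech-1.1 layout: a `wavs/` folder beside a pipe-delimited `metadata.csv` whose three fields are utterance id, raw transcript, and normalized transcript (only the third is kept). First entry, abbreviated:

```
LJ001-0001|Printing, in the only sense with which we are at present concerned,...|Printing, in the only sense with which we are at present concerned,...
```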
src/f5_tts/train/train.py
CHANGED
@@ -1,100 +1,71 @@
 # training script.
-
+import os
 from importlib.resources import files
 
+import hydra
+
 from f5_tts.model import CFM, DiT, Trainer, UNetT
 from f5_tts.model.dataset import load_dataset
 from f5_tts.model.utils import get_tokenizer
 
-
-target_sample_rate = 24000
-n_mel_channels = 100
-hop_length = 256
-win_length = 1024
-n_fft = 1024
-mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
-
-tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
-tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
-dataset_name = "Emilia_ZH_EN"
-
-# -------------------------- Training Settings -------------------------- #
-
-exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base
-
-learning_rate = 7.5e-5
-
-batch_size_per_gpu = 38400  # 8 GPUs, 8 * 38400 = 307200
-batch_size_type = "frame"  # "frame" or "sample"
-max_samples = 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
-grad_accumulation_steps = 1  # note: updates = steps / grad_accumulation_steps
-max_grad_norm = 1.0
-
-epochs = 11  # use linear decay, thus epochs control the slope
-num_warmup_updates = 20000  # warmup steps
-save_per_updates = 50000  # save checkpoint per steps
-last_per_steps = 5000  # save last checkpoint per steps
-
-if exp_name == "F5TTS_Base":
-    wandb_resume_id = None
-    model_cls = DiT
-    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-elif exp_name == "E2TTS_Base":
-    wandb_resume_id = None
-    model_cls = UNetT
-    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-
-# ----------------------------------------------------------------------- #
+os.chdir(str(files("f5_tts").joinpath("../..")))
 
 
-def main():
-    if tokenizer == "custom":
-        tokenizer_path = tokenizer_path
+@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
+def main(cfg):
+    tokenizer = cfg.model.tokenizer
+    mel_spec_type = cfg.model.mel_spec.mel_spec_type
+    exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
+
+    # set text tokenizer
+    if tokenizer != "custom":
+        tokenizer_path = cfg.datasets.name
     else:
-        tokenizer_path = dataset_name
+        tokenizer_path = cfg.model.tokenizer_path
     vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
 
-    mel_spec_kwargs = dict(
-        n_fft=n_fft,
-        hop_length=hop_length,
-        win_length=win_length,
-        n_mel_channels=n_mel_channels,
-        target_sample_rate=target_sample_rate,
-        mel_spec_type=mel_spec_type,
-    )
+    # set model
+    if "F5TTS" in cfg.model.name:
+        model_cls = DiT
+    elif "E2TTS" in cfg.model.name:
+        model_cls = UNetT
+    wandb_resume_id = None
 
     model = CFM(
-        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
-        mel_spec_kwargs=mel_spec_kwargs,
+        transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
+        mel_spec_kwargs=cfg.model.mel_spec,
         vocab_char_map=vocab_char_map,
     )
 
+    # init trainer
     trainer = Trainer(
         model,
-        epochs,
-        learning_rate,
-        num_warmup_updates=num_warmup_updates,
-        save_per_updates=save_per_updates,
-        checkpoint_path=str(files("f5_tts").joinpath(f"../../ckpts/{exp_name}")),
-        batch_size=batch_size_per_gpu,
-        batch_size_type=batch_size_type,
-        max_samples=max_samples,
-        grad_accumulation_steps=grad_accumulation_steps,
-        max_grad_norm=max_grad_norm,
+        epochs=cfg.optim.epochs,
+        learning_rate=cfg.optim.learning_rate,
+        num_warmup_updates=cfg.optim.num_warmup_updates,
+        save_per_updates=cfg.ckpts.save_per_updates,
+        checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
+        batch_size=cfg.datasets.batch_size_per_gpu,
+        batch_size_type=cfg.datasets.batch_size_type,
+        max_samples=cfg.datasets.max_samples,
+        grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
+        max_grad_norm=cfg.optim.max_grad_norm,
+        logger=cfg.ckpts.logger,
         wandb_project="CFM-TTS",
         wandb_run_name=exp_name,
         wandb_resume_id=wandb_resume_id,
-        last_per_steps=last_per_steps,
+        last_per_steps=cfg.ckpts.last_per_steps,
         log_samples=True,
+        bnb_optimizer=cfg.optim.bnb_optimizer,
         mel_spec_type=mel_spec_type,
+        is_local_vocoder=cfg.model.mel_spec.is_local_vocoder,
+        local_vocoder_path=cfg.model.mel_spec.local_vocoder_path,
     )
 
-    train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
+    train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
     trainer.train(
         train_dataset,
+        num_workers=cfg.datasets.num_workers,
        resumable_with_seed=666,  # seed for shuffling dataset
     )
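A quick way to sanity-check one of the new configs without launching training is to load it with OmegaConf directly (a sketch; `omegaconf` comes in as a dependency of `hydra-core`, and the `${now:...}` interpolations only resolve under hydra itself, so avoid forcing resolution of the ckpts paths):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("src/f5_tts/configs/F5TTS_Base_train.yaml")
print(OmegaConf.to_yaml(cfg.model.arch))  # dim, depth, heads, ff_mult, text_dim, conv_layers
print(cfg.optim.learning_rate)            # 7.5e-05
```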