Gapeleon committed on
Commit
5d5b597
·
1 Parent(s): 88ad001

Prepare HF Space

Browse files
Files changed (7) hide show
  1. README.md +17 -8
  2. launch.py +37 -0
  3. mira/__init__.py +1 -0
  4. mira/model.py +74 -0
  5. mira/utils.py +11 -0
  6. requirements.txt +11 -0
  7. web_ui.py +336 -0
README.md CHANGED
@@ -1,12 +1,21 @@
 
 
1
  ---
2
- title: Mira TTS
3
- emoji: 📊
4
- colorFrom: red
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 6.1.0
8
- app_file: app.py
9
  pinned: false
 
 
 
 
 
 
 
 
 
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # MiraTTS
2
+
3
  ---
4
+ title: Mira-TTS
5
+ emoji: ⚡
6
+ colorFrom: yellow
7
+ colorTo: yellow
8
  sdk: gradio
9
+ sdk_version: 5.50.0
10
+ app_file: web_ui.py
11
  pinned: false
12
+ license: apache-2.0
13
+ short_description: (Unofficial) Gradio demo for MiraTTS
14
+ models:
15
+ - YatharthS/MiraTTS
16
+ tags:
17
+ - text-to-speech
18
+ - voice-cloning
19
+ - speech-synthesis
20
+ python_version: "3.12"
21
  ---
 
 
launch.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Launch script for MiraTTS Web Interface
4
+ Simple wrapper to start the web UI with common configurations
5
+ """
6
+
7
+ import subprocess
8
+ import sys
9
+ import argparse
10
+
11
def main():
    """Parse launcher options and run ``web_ui.py`` in a child interpreter.

    Recognised options are ``--port``, ``--host``, ``--share`` and ``--model``;
    they are forwarded to the web UI as ``--server_port``, ``--server_name``,
    ``--share`` and ``--model_dir`` respectively.

    Returns:
        int: Exit status of the web UI process, so a failed start-up is
        reported to the shell instead of being silently discarded.
    """
    # Local import: the module header only imports subprocess, sys and argparse.
    import os

    parser = argparse.ArgumentParser(description="Launch MiraTTS Web Interface")
    parser.add_argument("--port", type=int, default=7860, help="Port to run on")
    parser.add_argument("--host", default="127.0.0.1", help="Host to bind to")
    parser.add_argument("--share", action="store_true", help="Create public share link")
    parser.add_argument("--model", default="YatharthS/MiraTTS", help="Model path or HF model ID")

    args = parser.parse_args()

    # Resolve the UI script next to this launcher so the command also works
    # when the launcher is started from a different working directory.
    script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "web_ui.py")

    cmd = [
        sys.executable, script_path,
        "--server_name", args.host,
        "--server_port", str(args.port),
        "--model_dir", args.model,
    ]

    if args.share:
        cmd.append("--share")

    print("Launching MiraTTS Web Interface...")
    print(f"Model: {args.model}")
    print(f"URL: http://{args.host}:{args.port}")

    # Argument list (no shell), so user-supplied values are never re-parsed.
    return subprocess.run(cmd).returncode


if __name__ == "__main__":
    sys.exit(main())
mira/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
mira/model.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import torch
3
+ from itertools import cycle
4
+ from ncodec.codec import TTSCodec
5
+ from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
6
+
7
+ from mira.utils import clear_cache, split_text
8
+
9
class MiraTTS:
    """Text-to-speech wrapper pairing an LMDeploy pipeline with a ``TTSCodec``.

    The pipeline produces text responses from formatted prompts; the codec
    formats those prompts from reference-audio tokens and decodes the
    responses back into audio.
    """

    def __init__(self, model_dir="YatharthS/MiraTTS", tp=1, enable_prefix_caching=True, cache_max_entry_count=0.2):
        """Create the pipeline, the default sampling config and the codec.

        Args:
            model_dir: Model path or model ID handed to ``lmdeploy.pipeline``.
            tp: Forwarded to ``TurbomindEngineConfig`` as ``tp``.
            enable_prefix_caching: Forwarded to ``TurbomindEngineConfig``.
            cache_max_entry_count: Forwarded to ``TurbomindEngineConfig``.
        """
        engine_options = TurbomindEngineConfig(
            cache_max_entry_count=cache_max_entry_count,
            tp=tp,
            dtype='bfloat16',
            enable_prefix_caching=enable_prefix_caching,
        )
        self.pipe = pipeline(model_dir, backend_config=engine_options)
        self.gen_config = GenerationConfig(
            top_p=0.95,
            top_k=50,
            temperature=0.8,
            max_new_tokens=1024,
            repetition_penalty=1.2,
            do_sample=True,
            min_p=0.05,
        )
        self.codec = TTSCodec()

    def set_params(self, top_p=0.95, top_k=50, temperature=0.8, max_new_tokens=1024, repetition_penalty=1.2, min_p=0.05):
        """Replace ``self.gen_config`` with a new sampling configuration.

        Sampling stays enabled (``do_sample=True``); every other value comes
        from the arguments, whose defaults match the constructor's config.
        """
        sampling = {
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "min_p": min_p,
            "do_sample": True,
        }
        self.gen_config = GenerationConfig(**sampling)

    def c_cache(self):
        """Delegate to the module-level ``clear_cache`` helper."""
        clear_cache()

    def split_text(self, text):
        """Delegate to the module-level ``split_text`` helper and return its result."""
        return split_text(text)

    def encode_audio(self, audio_file):
        """Encode ``audio_file`` with the codec and return the context tokens."""
        return self.codec.encode(audio_file)

    def generate(self, text, context_tokens):
        """Synthesize speech for a single ``text`` prompt.

        Args:
            text: Text to speak.
            context_tokens: Reference tokens, as returned by ``encode_audio``.

        Returns:
            Whatever ``self.codec.decode`` returns for the first response.
        """
        prompt = self.codec.format_prompt(text, context_tokens, None)
        # The prompt is already formatted by the codec, so LMDeploy's own
        # preprocessing is switched off.
        responses = self.pipe([prompt], gen_config=self.gen_config, do_preprocess=False)
        return self.codec.decode(responses[0].text, context_tokens)

    def batch_generate(self, prompts, context_tokens):
        """Synthesize speech for several prompts in one pipeline call.

        Args:
            prompts (list): Texts to speak.
            context_tokens (list): Reference tokens, paired with ``prompts``
                in order and cycled when there are fewer of them than prompts.

        Returns:
            The decoded audios concatenated with ``torch.cat`` along dim 0.
        """
        formatted = [
            self.codec.format_prompt(prompt, reference, None)
            for prompt, reference in zip(prompts, cycle(context_tokens))
        ]

        responses = self.pipe(formatted, gen_config=self.gen_config, do_preprocess=False)
        response_texts = [response.text for response in responses]

        # A fresh cycle restarts the pairing from the first reference.
        decoded = [
            self.codec.decode(response_text, reference)
            for response_text, reference in zip(response_texts, cycle(context_tokens))
        ]
        return torch.cat(decoded, dim=0)
73
+
74
+
mira/utils.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import gc
3
+ import torch
4
+
5
def split_text(text):
    """Split *text* into sentences at whitespace that follows ``.``, ``!`` or ``?``.

    Leading and trailing whitespace is stripped and empty pieces are dropped,
    so empty or whitespace-only input yields ``[]`` rather than ``['']``, and
    text ending in whitespace does not produce an empty final chunk.

    Args:
        text: The text to split; ``None`` or ``""`` is treated as empty.

    Returns:
        list[str]: The non-empty sentence chunks, in order.
    """
    if not text:
        return []
    # The lookbehind keeps the terminating punctuation attached to its sentence.
    pieces = re.split(r'(?<=[.!?])\s+', text.strip())
    return [piece for piece in pieces if piece]
8
+
9
def clear_cache():
    """Run Python garbage collection, then call ``torch.cuda.empty_cache()``."""
    # Step 1: collect unreachable Python objects.
    gc.collect()
    # Step 2: ask PyTorch to release its cached CUDA memory.
    torch.cuda.empty_cache()
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ lmdeploy
2
+ librosa
3
+ fastaudiosr @ git+https://github.com/ysharma3501/FlashSR.git
4
+ ncodec @ git+https://github.com/ysharma3501/FastBiCodec.git
5
+ einops
6
+ onnxruntime-gpu
7
+ soundfile
8
+ torch
9
+ torchaudio
10
+ transformers
11
+ omegaconf
web_ui.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import soundfile as sf
4
+ import logging
5
+ import argparse
6
+ import gradio as gr
7
+ from datetime import datetime
8
+ from mira.model import MiraTTS
9
+
10
# Shared MiraTTS instance; remains None until a model has been loaded.
MODEL = None


def initialize_model(model_dir="YatharthS/MiraTTS"):
    """Log the model location, build a ``MiraTTS`` for it and return it.

    Args:
        model_dir: Model path or model ID passed straight to ``MiraTTS``.
    """
    logging.info(f"Loading MiraTTS model from: {model_dir}")
    return MiraTTS(model_dir)
17
+
18
def generate_audio(text, prompt_audio_path):
    """Generate cloned-voice audio for ``text`` from a reference recording.

    Loads the shared model with its default settings if none is loaded yet,
    encodes the reference audio, generates speech and converts the result
    into an array ``soundfile`` can write.

    Args:
        text: Text to synthesize.
        prompt_audio_path: Path of the reference audio to clone.

    Returns:
        tuple: ``(audio, 48000)`` - the audio array and the fixed sample rate.

    Raises:
        Exception: Any error from encoding, generation or conversion is
            logged with its traceback and re-raised unchanged.
    """
    global MODEL

    if MODEL is None:
        MODEL = initialize_model()

    try:
        context_tokens = MODEL.encode_audio(prompt_audio_path)
        audio = MODEL.generate(text, context_tokens)

        if torch.is_tensor(audio):
            # detach() lets tensors that still track gradients be converted.
            audio = audio.detach().cpu()
            # Tensor.numpy() rejects bfloat16, and float16 would be upcast
            # below anyway, so widen both before leaving torch.
            # NOTE(review): the codec's output dtype is not visible here.
            if audio.dtype in (torch.float16, torch.bfloat16):
                audio = audio.float()
            audio = audio.numpy()

        # Anything outside these dtypes (including float16) becomes float32.
        if audio.dtype not in ('float32', 'float64', 'int16', 'int32'):
            audio = audio.astype('float32')

        return audio, 48000  # Return audio and sample rate
    except Exception as e:
        # logging.exception keeps the traceback; bare raise keeps the original one.
        logging.exception(f"Error during generation: {e}")
        raise
46
+
47
def run_tts(text, prompt_audio_path, save_dir="results"):
    """Run voice-cloning inference and write the result to a WAV file.

    Args:
        text: Text to synthesize.
        prompt_audio_path: Path of the reference audio to clone.
        save_dir: Directory for the output file; created if missing.

    Returns:
        str: Path of the written ``mira_tts_<timestamp>.wav`` file.

    Raises:
        Exception: Whatever ``generate_audio`` or ``soundfile.write`` raise.
    """
    logging.info(f"Saving audio to: {save_dir}")

    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Microseconds are part of the name so that two requests finishing within
    # the same second do not overwrite each other's output file.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S_%f")
    save_path = os.path.join(save_dir, f"mira_tts_{timestamp}.wav")

    logging.info("Starting MiraTTS inference...")

    audio, sample_rate = generate_audio(text, prompt_audio_path)
    sf.write(save_path, audio, samplerate=sample_rate)

    logging.info(f"Audio saved at: {save_path}")
    return save_path
68
+
69
def voice_clone_callback(text, prompt_audio_upload, prompt_audio_record):
    """Gradio callback for the voice-cloning tab.

    Args:
        text: Text to synthesize.
        prompt_audio_upload: File path of an uploaded reference, or None.
        prompt_audio_record: File path of a recorded reference, or None.
            Only used when no upload was provided.

    Returns:
        str | None: Path of the generated WAV file, or None when the text or
        the reference audio is missing (a warning is shown in that case).

    Raises:
        gr.Error: When synthesis fails, so the UI shows the reason instead of
            silently leaving the output empty.
    """
    if not text or not text.strip():
        gr.Warning("Please enter some text to synthesize.")
        return None

    # The uploaded file takes precedence over the microphone recording.
    prompt_audio = prompt_audio_upload or prompt_audio_record
    if not prompt_audio:
        gr.Warning("Please upload or record a reference audio sample.")
        return None

    try:
        return run_tts(text, prompt_audio)
    except Exception as e:
        logging.exception(f"Error in voice cloning: {e}")
        raise gr.Error(f"Voice cloning failed: {e}") from e
86
+
87
def voice_creation_callback(text, temperature, top_p, top_k):
    """Gradio callback that synthesizes ``text`` with custom sampling settings.

    A bundled example recording is used as the reference voice. The custom
    sampling settings apply to this request only: the model's previous
    generation config is restored afterwards so other tabs sharing the model
    are not affected.

    Args:
        text: Text to synthesize.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k cutoff; coerced to ``int`` in case the slider yields a float.

    Returns:
        str | None: Path of the written ``mira_tts_creation_<timestamp>.wav``
        file, or None when the text is empty or no example audio is found.

    Raises:
        gr.Error: When generation or saving fails.
    """
    global MODEL

    if not text or not text.strip():
        return None

    if MODEL is None:
        MODEL = initialize_model()

    # Candidate locations for the default reference voice, checked in order.
    # NOTE(review): the first entry looks like a machine-specific absolute path.
    possible_paths = (
        "/models3/src/MiraTTS/models/MiraTTS/example1.wav",
        "models/MiraTTS/example1.wav",
        "./models/MiraTTS/example1.wav",
    )
    default_audio = next((path for path in possible_paths if os.path.exists(path)), None)

    if default_audio is None:
        logging.warning("No default audio found for voice creation")
        gr.Warning("No default reference audio found for voice creation.")
        return None

    # Remember the current config so the override below can be undone.
    previous_config = getattr(MODEL, "gen_config", None)
    try:
        MODEL.set_params(
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
            max_new_tokens=1024,
            repetition_penalty=1.2,
        )

        # Reuse the shared generation path, including its dtype handling.
        audio, sample_rate = generate_audio(text, default_audio)

        os.makedirs("results", exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        save_path = os.path.join("results", f"mira_tts_creation_{timestamp}.wav")
        sf.write(save_path, audio, samplerate=sample_rate)

        return save_path
    except Exception as e:
        logging.exception(f"Error in voice creation: {e}")
        raise gr.Error(f"Voice creation failed: {e}") from e
    finally:
        if previous_config is not None:
            MODEL.gen_config = previous_config
150
+
151
def _voice_clone_tab():
    """Add the "Voice Clone" tab; call while a ``gr.Tabs`` context is open."""
    with gr.TabItem("Voice Clone"):
        gr.Markdown("### Clone any voice using a reference audio sample")

        # Two alternative ways of supplying the reference voice, side by side.
        with gr.Row():
            uploaded_reference = gr.Audio(
                sources="upload",
                type="filepath",
                label="Upload Reference Audio (recommended: 3-30 seconds, 16kHz+)",
            )
            recorded_reference = gr.Audio(
                sources="microphone",
                type="filepath",
                label="Record Reference Audio",
            )

        script_box = gr.Textbox(
            label="Text to Synthesize",
            lines=3,
            placeholder="Enter the text you want to convert to speech...",
            value="Hello! This is a demonstration of MiraTTS voice cloning capabilities."
        )

        with gr.Row():
            generate_button = gr.Button("Generate Audio", variant="primary")
            reset_button = gr.Button("Clear")

        result_player = gr.Audio(
            label="Generated Audio",
            autoplay=True
        )

        generate_button.click(
            voice_clone_callback,
            inputs=[script_box, uploaded_reference, recorded_reference],
            outputs=[result_player],
        )

        # Reset both reference inputs, the text box and the player.
        reset_button.click(
            lambda: (None, None, "", None),
            outputs=[uploaded_reference, recorded_reference, script_box, result_player]
        )


def _voice_creation_tab():
    """Add the "Voice Creation" tab; call while a ``gr.Tabs`` context is open."""
    with gr.TabItem("Voice Creation"):
        gr.Markdown("### Create synthetic voices with custom parameters")

        with gr.Row():
            # Left column: text and sampling controls.
            with gr.Column():
                script_box = gr.Textbox(
                    label="Text to Synthesize",
                    lines=3,
                    placeholder="Enter text here...",
                    value="You can create customized voices by adjusting the generation parameters below."
                )

                with gr.Row():
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.5,
                        step=0.1,
                        value=0.8,
                        label="Temperature (creativity)"
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        step=0.05,
                        value=0.95,
                        label="Top-p (nucleus sampling)"
                    )
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=50,
                        label="Top-k (vocabulary size)"
                    )

            # Right column: trigger button and result player.
            with gr.Column():
                create_button = gr.Button("Create Voice", variant="primary")
                result_player = gr.Audio(
                    label="Generated Audio",
                    autoplay=True
                )

        create_button.click(
            voice_creation_callback,
            inputs=[script_box, temperature_slider, top_p_slider, top_k_slider],
            outputs=[result_player],
        )


def _about_tab():
    """Add the static "About" tab; call while a ``gr.Tabs`` context is open."""
    with gr.TabItem("About"):
        gr.Markdown("""
        ## About MiraTTS

        MiraTTS is an optimized version of Spark-TTS with the following features:

        - **Ultra-fast generation**: Over 100x realtime speed using LMDeploy optimization
        - **High quality**: Generates crisp 48kHz audio outputs
        - **Memory efficient**: Works within 6GB VRAM
        - **Low latency**: As low as 100ms generation time
        - **Voice cloning**: Clone any voice from a short audio sample

        ### Model Information
        - Base model: Spark-TTS-0.5B
        - Optimization: LMDeploy + FlashSR
        - Sample rate: 48kHz
        - Model size: ~500M parameters

        ### Usage Tips
        - For voice cloning, use clear audio samples between 3-30 seconds
        - Ensure reference audio is at least 16kHz quality
        - Longer text inputs may require more memory
        - Adjust generation parameters for different voice styles
        """)


def build_ui():
    """Assemble the Gradio Blocks app and return it without launching it.

    The page has a title, a short description and three tabs ("Voice Clone",
    "Voice Creation", "About"), created in that order.
    """
    with gr.Blocks(title="MiraTTS Web Interface") as demo:
        gr.HTML('<h1 style="text-align: center;">MiraTTS - High Quality Voice Synthesis</h1>')

        gr.Markdown("""
        MiraTTS is a highly optimized Text-to-Speech model based on Spark-TTS with LMDeploy acceleration.
        It provides over 100x realtime generation speed with high-quality 48kHz audio output.
        """)

        with gr.Tabs():
            _voice_clone_tab()
            _voice_creation_tab()
            _about_tab()

    return demo
284
+
285
def parse_arguments():
    """Read the web UI's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace: with ``model_dir``, ``server_name``,
        ``server_port`` and ``share`` attributes.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--model_dir", {
            "type": str,
            "default": "YatharthS/MiraTTS",
            "help": "Path to the MiraTTS model directory or HuggingFace model ID",
        }),
        ("--server_name", {
            "type": str,
            "default": "127.0.0.1",
            "help": "Server host/IP for Gradio app",
        }),
        ("--server_port", {
            "type": int,
            "default": 7860,
            "help": "Server port for Gradio app",
        }),
        ("--share", {
            "action": "store_true",
            "help": "Create a public shareable link",
        }),
    )

    cli = argparse.ArgumentParser(description="MiraTTS Gradio Web Interface")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
312
+
313
if __name__ == "__main__":
    # Script entry point: set up logging, load the model, then serve the UI.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )

    args = parse_arguments()

    # Load the model up front and publish it through the module-level MODEL
    # so the callbacks use the model chosen on the command line.
    logging.info("Initializing MiraTTS model...")
    MODEL = initialize_model(args.model_dir)

    logging.info("Building Gradio interface...")
    demo = build_ui()

    logging.info(f"Launching web interface on {args.server_name}:{args.server_port}")
    demo.launch(
        share=args.share,
        server_name=args.server_name,
        server_port=args.server_port,
    )