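"""Gradio Space for Arabic text-to-speech with the Spark model.

The app clones a voice from a short reference clip and its diacritized
transcript, splits long Arabic input into chunks, synthesizes each chunk,
and stitches the results together with crossfades and silence trimming.
"""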
import gradio as gr
import soundfile as sf
import torch
import numpy as np
from pathlib import Path
from transformers import AutoProcessor, AutoModel
import tempfile
import os
import spaces
import shutil
from typing import List

# Helper functions: Arabic text chunking and audio post-processing
def smart_text_split_arabic(text: str, max_length: int = 300) -> List[str]:
    """Intelligently split Arabic text into chunks while preserving context."""
    if len(text) <= max_length:
        return [text]
    chunks = []
    remaining_text = text.strip()
    while remaining_text:
        if len(remaining_text) <= max_length:
            chunks.append(remaining_text)
            break
        chunk = remaining_text[:max_length]
        split_point = -1
        # Priority 1: Sentence endings
        sentence_endings = ['.', '!', '?', '۔']
        for i in range(len(chunk) - 1, max(0, max_length - 100), -1):
            if chunk[i] in sentence_endings:
                if i == len(chunk) - 1 or chunk[i + 1] == ' ':
                    split_point = i + 1
                    break
        # Priority 2: Arabic clause separators
        if split_point == -1:
            arabic_separators = ['،', '؛', ':', ';', ',']
            for i in range(len(chunk) - 1, max(0, max_length - 50), -1):
                if chunk[i] in arabic_separators:
                    if i == len(chunk) - 1 or chunk[i + 1] == ' ':
                        split_point = i + 1
                        break
        # Priority 3: Word boundaries
        if split_point == -1:
            for i in range(len(chunk) - 1, max(0, max_length - 30), -1):
                if chunk[i] == ' ':
                    split_point = i + 1
                    break
        if split_point == -1:
            split_point = max_length
        current_chunk = remaining_text[:split_point].strip()
        if current_chunk:
            chunks.append(current_chunk)
        remaining_text = remaining_text[split_point:].strip()
    return chunks

def apply_crossfade(audio1: np.ndarray, audio2: np.ndarray,
                    fade_duration: float = 0.1, sample_rate: int = 24000) -> np.ndarray:
    """Apply crossfade between two audio segments."""
    fade_samples = int(fade_duration * sample_rate)
    fade_samples = min(fade_samples, len(audio1), len(audio2))
    if fade_samples <= 0:
        return np.concatenate([audio1, audio2])
    fade_out = np.linspace(1.0, 0.0, fade_samples)
    fade_in = np.linspace(0.0, 1.0, fade_samples)
    audio1_faded = audio1.copy()
    audio2_faded = audio2.copy()
    audio1_faded[-fade_samples:] *= fade_out
    audio2_faded[:fade_samples] *= fade_in
    overlap = audio1_faded[-fade_samples:] + audio2_faded[:fade_samples]
    result = np.concatenate([
        audio1_faded[:-fade_samples],
        overlap,
        audio2_faded[fade_samples:]
    ])
    return result

def normalize_audio(audio: np.ndarray, target_rms: float = 0.1) -> np.ndarray:
    """Normalize audio to target RMS level."""
    if len(audio) == 0:
        return audio
    current_rms = np.sqrt(np.mean(audio ** 2))
    if current_rms > 1e-6:
        scaling_factor = target_rms / current_rms
        return audio * scaling_factor
    return audio

def remove_silence(audio: np.ndarray, sample_rate: int = 24000,
                   silence_threshold: float = 0.01, min_silence_duration: float = 0.5) -> np.ndarray:
    """Remove long silences from audio."""
    if len(audio) == 0:
        return audio
    frame_size = int(0.05 * sample_rate)
    min_silence_frames = int(min_silence_duration / 0.05)
    frames = []
    for i in range(0, len(audio), frame_size):
        frame = audio[i:i + frame_size]
        if len(frame) < frame_size:
            frames.append(frame)
            break
        rms = np.sqrt(np.mean(frame ** 2))
        frames.append(frame if rms > silence_threshold else None)
    result_frames = []
    silence_count = 0
    for frame in frames:
        if frame is None:
            silence_count += 1
        else:
            if silence_count > 0:
                if silence_count >= min_silence_frames:
                    for _ in range(min(2, silence_count)):
                        result_frames.append(np.zeros(frame_size, dtype=np.float32))
                else:
                    for _ in range(silence_count):
                        result_frames.append(np.zeros(frame_size, dtype=np.float32))
            result_frames.append(frame)
            silence_count = 0
    if not result_frames:
        return np.array([], dtype=np.float32)
    return np.concatenate(result_frames)

# Global model instance
model_cache = {}

def load_model(model_id: str = "IbrahimSalah/Arabic-TTS-Spark"):
    """Load the TTS model (cached)."""
    if "model" not in model_cache:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading model on {device}...")
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval().to(device)
        processor.model = model
        model_cache["model"] = model
        model_cache["processor"] = processor
        model_cache["device"] = device
        print("Model loaded successfully!")
    return model_cache["model"], model_cache["processor"], model_cache["device"]

# Request GPU for 120 seconds per call (Hugging Face ZeroGPU)
@spaces.GPU(duration=120)
def generate_speech(
    text: str,
    reference_audio,
    reference_transcript: str,
    temperature: float = 0.8,
    top_p: float = 0.95,
    max_chunk_length: int = 300,
    crossfade_duration: float = 0.08,
    progress=gr.Progress()
):
| """Generate speech from text using Spark TTS.""" | |
| try: | |
| # Load model | |
| progress(0.1, desc="Loading model...") | |
| model, processor, device = load_model() | |
| # Validate inputs | |
| if not text.strip(): | |
| return None, "❌ Please enter text to synthesize." | |
| if reference_audio is None: | |
| return None, "❌ Please upload a reference audio file." | |
| if not reference_transcript.strip(): | |
| return None, "❌ Please enter the reference transcript." | |
| # Split text into chunks | |
| progress(0.2, desc="Splitting text...") | |
| text_chunks = smart_text_split_arabic(text, max_chunk_length) | |
| audio_segments = [] | |
| sample_rate = None | |
| # Generate audio for each chunk | |
| for i, chunk in enumerate(text_chunks): | |
| progress(0.2 + (0.6 * (i / len(text_chunks))), desc=f"Generating chunk {i+1}/{len(text_chunks)}...") | |
| inputs = processor( | |
| text=chunk.lower(), | |
| prompt_speech_path=reference_audio, | |
| prompt_text=reference_transcript, | |
| return_tensors="pt" | |
| ).to(device) | |
| global_tokens_prompt = inputs.pop("global_token_ids_prompt", None) | |
| with torch.no_grad(): | |
| output_ids = model.generate( | |
| **inputs, | |
| max_new_tokens=8000, | |
| do_sample=True, | |
| temperature=temperature, | |
| top_k=50, | |
| top_p=top_p, | |
| eos_token_id=processor.tokenizer.eos_token_id, | |
| pad_token_id=processor.tokenizer.pad_token_id | |
| ) | |
| output = processor.decode( | |
| generated_ids=output_ids, | |
| global_token_ids_prompt=global_tokens_prompt, | |
| input_ids_len=inputs["input_ids"].shape[-1] | |
| ) | |
| audio = output["audio"] | |
| if isinstance(audio, torch.Tensor): | |
| audio = audio.cpu().numpy() | |
| if sample_rate is None: | |
| sample_rate = output["sampling_rate"] | |
| # Post-process | |
| audio = normalize_audio(audio, target_rms=0.1) | |
| audio = remove_silence(audio, sample_rate) | |
| if len(audio) > 0: | |
| audio_segments.append(audio) | |
| if not audio_segments: | |
| return None, "❌ No audio was generated." | |
| # Concatenate segments | |
| progress(0.9, desc="Concatenating audio...") | |
| final_audio = audio_segments[0] | |
| for i in range(1, len(audio_segments)): | |
| final_audio = apply_crossfade( | |
| final_audio, audio_segments[i], | |
| fade_duration=crossfade_duration, | |
| sample_rate=sample_rate | |
| ) | |
| # Final normalization | |
| final_audio = normalize_audio(final_audio, target_rms=0.1) | |
| # Save to temporary file | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: | |
| sf.write(tmp_file.name, final_audio, sample_rate) | |
| output_path = tmp_file.name | |
| duration = len(final_audio) / sample_rate | |
| status = f"✅ Generated {duration:.2f}s audio from {len(text_chunks)} chunks" | |
| progress(1.0, desc="Complete!") | |
| return output_path, status | |
| except Exception as e: | |
| import traceback | |
| error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}" | |
| print(error_msg) | |
| return None, error_msg | |

# Default examples
DEFAULT_REFERENCE_TEXT = "لَا يَمُرُّ يَوْمٌ إِلَّا وَأَسْتَقْبِلُ عِدَّةَ رَسَائِلَ، تَتَضَمَّنُ أَسْئِلَةً مُلِحَّةْ."
DEFAULT_TEXT = "تُسَاهِمُ التِّقْنِيَّاتُ الْحَدِيثَةُ فِي تَسْهِيلِ حَيَاةِ الْإِنْسَانِ، وَذَلِكَ مِنْ خِلَالِ تَطْوِيرِ أَنْظِمَةٍ ذَكِيَّةٍ تَعْتَمِدُ عَلَى الذَّكَاءِ الِاصْطِنَاعِيِّ."

# Path to default reference audio
DEFAULT_REFERENCE_AUDIO = "reference.wav"

# Create Gradio interface
with gr.Blocks(title="Arabic TTS - Spark", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎙️ Arabic Text-to-Speech | Spark Model
    High-quality Arabic TTS with voice cloning. **Diacritized text (تشكيل) required.**

    **Model:** [IbrahimSalah/Arabic-TTS-Spark](https://huggingface.co/IbrahimSalah/Arabic-TTS-Spark)
    """)
    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="📝 Text to Synthesize (Arabic with Tashkeel)",
                placeholder="أَدْخِلْ نَصًّا عَرَبِيًّا مُشَكَّلًا هُنَا...",
                lines=6,
                value=DEFAULT_TEXT
            )
            with gr.Row():
                with gr.Column():
                    gr.Markdown("**🎵 Reference Audio**")
                    reference_audio = gr.Audio(
                        label="",
                        type="filepath",
                        value=DEFAULT_REFERENCE_AUDIO
                    )
                with gr.Column():
                    reference_transcript = gr.Textbox(
                        label="📄 Reference Transcript (with Tashkeel)",
                        placeholder="النص المقابل للصوت المرجعي...",
                        lines=4,
                        value=DEFAULT_REFERENCE_TEXT
                    )
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                with gr.Row():
                    temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top P")
                with gr.Row():
                    max_chunk = gr.Slider(100, 500, value=300, step=50, label="Max Chunk Length")
                    crossfade = gr.Slider(0.01, 0.2, value=0.08, step=0.01, label="Crossfade (s)")
            generate_btn = gr.Button("🎤 Generate Speech", variant="primary", size="lg")
        with gr.Column(scale=1):
            output_audio = gr.Audio(label="🔊 Generated Speech", type="filepath")
            status_text = gr.Textbox(label="Status", interactive=False, lines=2)
| gr.Markdown(""" | |
| ### ℹ️ Requirements | |
| - **Diacritized text is required** (تشكيل/تشكيل) | |
| - Reference audio: 5-30 seconds, clear speech | |
| - Use AI (ChatGPT/Claude) or [online tools](https://tahadz.com/mishkal) to add diacritics | |
| ### 🔗 Resources | |
| - [Model Card](https://huggingface.co/IbrahimSalah/Arabic-TTS-Spark) | |
| - [F5-TTS Arabic](https://huggingface.co/IbrahimSalah/Arabic-F5-TTS-v2) | |
| - [Report Issues](https://huggingface.co/IbrahimSalah/Arabic-TTS-Spark/discussions) | |
| """) | |

    # Examples
    with gr.Accordion("📚 Examples", open=False):
        gr.Examples(
            examples=[
                [DEFAULT_TEXT, DEFAULT_REFERENCE_AUDIO, DEFAULT_REFERENCE_TEXT],
                ["السَّلَامُ عَلَيْكُمْ وَرَحْمَةُ اللَّهِ وَبَرَكَاتُهُ، كَيْفَ حَالُكَ الْيَوْمَ؟", DEFAULT_REFERENCE_AUDIO, DEFAULT_REFERENCE_TEXT],
                ["الذَّكَاءُ الِاصْطِنَاعِيُّ يُغَيِّرُ الْعَالَمَ بِسُرْعَةٍ كَبِيرَةٍ وَيُسَاهِمُ فِي تَطْوِيرِ حُلُولٍ مُبْتَكَرَةٍ.", DEFAULT_REFERENCE_AUDIO, DEFAULT_REFERENCE_TEXT]
            ],
            inputs=[text_input, reference_audio, reference_transcript]
        )

    generate_btn.click(
        fn=generate_speech,
        inputs=[text_input, reference_audio, reference_transcript, temperature, top_p, max_chunk, crossfade],
        outputs=[output_audio, status_text]
    )

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch()