"""Persian text-to-speech with voice cloning.

Pipeline: a speaker encoder produces an utterance embedding from a reference
recording, a Tacotron-style synthesizer turns text plus embedding into
mel-spectrograms, and a HiFi-GAN vocoder renders the final waveform.
"""

import os
import re
import sys
import traceback

import numpy as np
import soundfile as sf
import spaces
import torch

from config import models_path, results_path, sample_path, BASE_DIR
from sentence_splitter import PersianSentenceSplitter
from text_utils import convert_number_to_text

# Models are loaded lazily by load_models() and kept as module-level globals.
encoder = None
synthesizer = None
vocoder = None
sentence_splitter = None


def load_models():
    """Load the encoder, synthesizer, and HiFi-GAN vocoder. Returns True on success."""
    global encoder, synthesizer, vocoder, sentence_splitter
    try:
        # The pmt2 directory ships the encoder/synthesizer/vocoder packages.
        sys.path.append(os.path.join(BASE_DIR, 'pmt2'))
        from encoder import inference as encoder_module
        from synthesizer.inference import Synthesizer
        from parallel_wavegan.utils import load_model as vocoder_hifigan

        encoder = encoder_module
        print("Loading encoder model...")
        encoder.load_model(os.path.join(models_path, 'encoder.pt'))

        print("Loading synthesizer model...")
        synthesizer = Synthesizer(os.path.join(models_path, 'synthesizer.pt'))

        print("Loading HiFiGAN vocoder...")
        vocoder = vocoder_hifigan(os.path.join(models_path, 'vocoder_HiFiGAN.pkl'))
        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to('cuda' if torch.cuda.is_available() else 'cpu')

        sentence_splitter = PersianSentenceSplitter(max_chars=150, min_chars=30)

        print("Models loaded successfully!")
        return True
    except Exception:
        print(f"Error loading models: {traceback.format_exc()}")
        return False


def normalize_text_for_synthesis(text: str) -> str:
    """Normalize Persian text before synthesis: unify characters, collapse
    whitespace, and spell out digits as words."""
    # Map Arabic kaf/yeh to their Persian forms, and underscores to ZWNJ.
    text = text.replace('ك', 'ک').replace('ي', 'ی')
    text = text.replace('_', '\u200c')
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()

    # Persian, ASCII, and Arabic-Indic digits, optionally with thousands separators.
    number_pattern = r'[۰-۹0-9٠-٩]+(?:[,،٬][۰-۹0-9٠-٩]+)*'

    def replace_number(match):
        num_str = match.group(0)
        try:
            return convert_number_to_text(num_str)
        except Exception:
            return num_str

    text = re.sub(number_pattern, replace_number, text)
    return text


def synthesize_segment(text_segment: str, embed: np.ndarray) -> np.ndarray:
    """Synthesize a single text segment to a waveform, or return None on failure."""
    try:
        text_segment = normalize_text_for_synthesis(text_segment)
        specs = synthesizer.synthesize_spectrograms([text_segment], [embed])
        spec = specs[0]
        x = torch.from_numpy(spec.T).to('cuda' if torch.cuda.is_available() else 'cpu')
        with torch.no_grad():
            wav = vocoder.inference(x)
        wav = wav.cpu().numpy()
        if wav.ndim > 1:
            wav = wav.squeeze()
        return wav
    except Exception:
        print(f"Error synthesizing segment '{text_segment[:50]}...': {traceback.format_exc()}")
        return None


def add_silence(duration_ms: int = 300) -> np.ndarray:
    """Return a silent buffer of the given duration at the synthesizer's sample rate."""
    sample_rate = synthesizer.sample_rate
    num_samples = int(sample_rate * duration_ms / 1000)
    return np.zeros(num_samples, dtype=np.float32)


@spaces.GPU(duration=120)
def generate_speech(text, reference_audio=None, add_pauses: bool = True):
    """Generate speech for `text` in the voice of `reference_audio`.

    `reference_audio` is a (sample_rate, samples) tuple as provided by a Gradio
    Audio component; if None, the bundled sample recording is used. Returns the
    path of the written WAV file, or None on failure.
    """
    if not text or text.strip() == "":
        return None
    try:
        if reference_audio is None:
            ref_wav_path = sample_path
        else:
            ref_wav_path = os.path.join(results_path, "reference_audio.wav")
            sf.write(ref_wav_path, reference_audio[1], reference_audio[0])

        print(f"Using reference audio: {ref_wav_path}")
        wav = synthesizer.load_preprocess_wav(ref_wav_path)
        encoder_wav = encoder.preprocess_wav(wav)
        embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        # Split long input into sentence-sized segments to keep the synthesizer stable.
        text_segments = sentence_splitter.split(text)
        print(f"Split text into {len(text_segments)} segments:")
        for i, segment in enumerate(text_segments, 1):
            print(f" Segment {i}: {segment[:60]}{'...' if len(segment) > 60 else ''}")

        audio_segments = []
        silence = add_silence(300) if add_pauses else None  # 300 ms pause between segments

        for i, segment in enumerate(text_segments):
            print(f"Processing segment {i+1}/{len(text_segments)}...")
            segment_wav = synthesize_segment(segment, embed)
            if segment_wav is not None:
                segment_wav = segment_wav.flatten() if segment_wav.ndim > 1 else segment_wav
                audio_segments.append(segment_wav)
                if add_pauses and i < len(text_segments) - 1:
                    audio_segments.append(silence)
            else:
                print(f"Warning: Failed to synthesize segment {i+1}")

        if not audio_segments:
            print("Error: No audio segments were generated successfully")
            return None

        audio_segments = [seg.flatten() if seg.ndim > 1 else seg for seg in audio_segments]
        final_wav = np.concatenate(audio_segments)

        # Peak-normalize to just below full scale; guard against an all-zero signal.
        peak = np.abs(final_wav).max()
        if peak > 0:
            final_wav = final_wav / peak * 0.97

        output_filename = f"generated_{abs(hash(text)) % 100000}.wav"
        output_path = os.path.join(results_path, output_filename)
        sf.write(output_path, final_wav, synthesizer.sample_rate)

        print(f"✓ Successfully generated speech: {output_path}")
        print(f" Total duration: {len(final_wav) / synthesizer.sample_rate:.2f} seconds")
        return output_path
    except Exception:
        error_details = traceback.format_exc()
        print(f"Error generating speech: {error_details}")
        return None
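

# Usage sketch (illustrative only, not part of the original module). This assumes
# the module is imported by a Gradio/Spaces app that calls load_models() once at
# startup and then generate_speech() per request. The Persian sentence below is a
# placeholder example, not text taken from the project.
if __name__ == "__main__":
    if load_models():
        # With reference_audio=None, the bundled sample recording (sample_path
        # from config.py) is used as the voice reference.
        out_path = generate_speech("این یک نمونهٔ آزمایشی است.", reference_audio=None)
        print("Generated file:", out_path)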