# mana-tts / synthesis.py
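"""Speech synthesis pipeline for mana-tts.

Loads a speaker encoder, a text-to-spectrogram synthesizer, and a HiFi-GAN
vocoder, then exposes generate_speech() to turn Persian text into a WAV file,
optionally conditioned on a user-provided reference audio clip.
"""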
import os
import sys
import re
import traceback
import numpy as np
import torch
import soundfile as sf
import spaces
from config import models_path, results_path, sample_path, BASE_DIR
from sentence_splitter import PersianSentenceSplitter
from text_utils import convert_number_to_text
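# Models are loaded lazily into these module-level globals by load_models().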
encoder = None
synthesizer = None
vocoder = None
sentence_splitter = None
def load_models():
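    """Load the speaker encoder, synthesizer, HiFi-GAN vocoder, and Persian
    sentence splitter into the module-level globals.

    Returns True on success and False if any model fails to load.
    """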
global encoder, synthesizer, vocoder, sentence_splitter
try:
sys.path.append(os.path.join(BASE_DIR, 'pmt2'))
from encoder import inference as encoder_module
from synthesizer.inference import Synthesizer
from parallel_wavegan.utils import load_model as vocoder_hifigan
        encoder = encoder_module
print("Loading encoder model...")
encoder.load_model(os.path.join(models_path, 'encoder.pt'))
print("Loading synthesizer model...")
synthesizer = Synthesizer(os.path.join(models_path, 'synthesizer.pt'))
print("Loading HiFiGAN vocoder...")
vocoder = vocoder_hifigan(os.path.join(models_path, 'vocoder_HiFiGAN.pkl'))
vocoder.remove_weight_norm()
vocoder = vocoder.eval().to('cuda' if torch.cuda.is_available() else 'cpu')
sentence_splitter = PersianSentenceSplitter(max_chars=150, min_chars=30)
print("Models loaded successfully!")
return True
    except Exception:
        print(f"Error loading models: {traceback.format_exc()}")
return False
def normalize_text_for_synthesis(text: str) -> str:
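    """Prepare raw text for the synthesizer: map Arabic characters to their
    Persian forms, turn underscores into zero-width non-joiners, collapse
    whitespace, and spell out digit sequences via convert_number_to_text."""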
text = text.replace('ك', 'ک').replace('ي', 'ی')
text = text.replace('_', '\u200c')
text = re.sub(r'\s+', ' ', text)
text = text.strip()
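    # Match runs of Persian (۰-۹), ASCII (0-9), or Arabic-Indic (٠-٩) digits,
    # optionally grouped by comma-like thousands separators.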
number_pattern = r'[۰-۹0-9٠-٩]+(?:[,،٬][۰-۹0-9٠-٩]+)*'
def replace_number(match):
num_str = match.group(0)
try:
return convert_number_to_text(num_str)
        except Exception:
return num_str
text = re.sub(number_pattern, replace_number, text)
return text
def synthesize_segment(text_segment: str, embed: np.ndarray) -> np.ndarray:
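    """Synthesize one text segment into a waveform using the given speaker
    embedding. Returns a 1-D float numpy array, or None if synthesis fails."""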
try:
text_segment = normalize_text_for_synthesis(text_segment)
specs = synthesizer.synthesize_spectrograms([text_segment], [embed])
spec = specs[0]
x = torch.from_numpy(spec.T).to('cuda' if torch.cuda.is_available() else 'cpu')
with torch.no_grad():
wav = vocoder.inference(x)
wav = wav.cpu().numpy()
if wav.ndim > 1:
wav = wav.squeeze()
return wav
    except Exception:
        print(f"Error synthesizing segment '{text_segment[:50]}...': {traceback.format_exc()}")
return None
def add_silence(duration_ms: int = 300) -> np.ndarray:
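    """Return duration_ms of silence at the synthesizer's sample rate."""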
sample_rate = synthesizer.sample_rate
num_samples = int(sample_rate * duration_ms / 1000)
return np.zeros(num_samples, dtype=np.float32)
@spaces.GPU(duration=120)
def generate_speech(text, reference_audio=None, add_pauses: bool = True):
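    """Generate speech for `text` and return the path of the written WAV file.

    reference_audio, if given, is a (sample_rate, samples) tuple used to compute
    the speaker embedding; otherwise the bundled sample at sample_path is used.
    The text is split into segments, each segment is synthesized separately, and
    the results are concatenated with short pauses when add_pauses is True.
    Returns None if the input is empty or synthesis fails.
    """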
if not text or text.strip() == "":
return None
try:
if reference_audio is None:
ref_wav_path = sample_path
else:
ref_wav_path = os.path.join(results_path, "reference_audio.wav")
sf.write(ref_wav_path, reference_audio[1], reference_audio[0])
print(f"Using reference audio: {ref_wav_path}")
wav = synthesizer.load_preprocess_wav(ref_wav_path)
encoder_wav = encoder.preprocess_wav(wav)
embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
text_segments = sentence_splitter.split(text)
print(f"Split text into {len(text_segments)} segments:")
for i, segment in enumerate(text_segments, 1):
print(f" Segment {i}: {segment[:60]}{'...' if len(segment) > 60 else ''}")
audio_segments = []
silence = add_silence(300) if add_pauses else None # 300ms pause
for i, segment in enumerate(text_segments):
print(f"Processing segment {i+1}/{len(text_segments)}...")
segment_wav = synthesize_segment(segment, embed)
if segment_wav is not None:
segment_wav = segment_wav.flatten() if segment_wav.ndim > 1 else segment_wav
audio_segments.append(segment_wav)
if add_pauses and i < len(text_segments) - 1:
audio_segments.append(silence)
else:
print(f"Warning: Failed to synthesize segment {i+1}")
if not audio_segments:
print("Error: No audio segments were generated successfully")
return None
        final_wav = np.concatenate(audio_segments)
        peak = np.abs(final_wav).max()
        if peak > 0:
            final_wav = final_wav / peak * 0.97  # peak-normalize to just below full scale
output_filename = f"generated_{abs(hash(text)) % 100000}.wav"
output_path = os.path.join(results_path, output_filename)
sf.write(output_path, final_wav, synthesizer.sample_rate)
print(f"✓ Successfully generated speech: {output_path}")
print(f" Total duration: {len(final_wav) / synthesizer.sample_rate:.2f} seconds")
return output_path
    except Exception:
        print(f"Error generating speech: {traceback.format_exc()}")
return None
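# Minimal local usage sketch (not part of the Space's Gradio app). It assumes the
# model checkpoints are available under models_path and that load_models() has not
# yet been called; the Persian sample sentence below is illustrative only.
if __name__ == "__main__":
    if load_models():
        out_path = generate_speech("سلام، این یک آزمایش است.")
        if out_path:
            print(f"Done: {out_path}")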