//! BigVGAN vocoder implementation //! //! High-quality neural vocoder for mel-spectrogram to waveform conversion use crate::{Error, Result}; use ndarray::{Array2, IxDyn}; use std::collections::HashMap; use std::path::Path; use crate::model::OnnxSession; use super::{Vocoder, snake_activation_vec}; /// BigVGAN configuration #[derive(Debug, Clone)] pub struct BigVGANConfig { /// Sample rate pub sample_rate: u32, /// Number of mel channels pub num_mels: usize, /// Upsampling rates pub upsample_rates: Vec, /// Upsampling kernel sizes pub upsample_kernel_sizes: Vec, /// ResBlock kernel sizes pub resblock_kernel_sizes: Vec, /// ResBlock dilation sizes pub resblock_dilation_sizes: Vec>, /// Initial channel size pub upsample_initial_channel: usize, /// Use anti-aliasing pub use_anti_alias: bool, } impl Default for BigVGANConfig { fn default() -> Self { Self { sample_rate: 22050, num_mels: 80, upsample_rates: vec![8, 8, 2, 2], upsample_kernel_sizes: vec![16, 16, 4, 4], resblock_kernel_sizes: vec![3, 7, 11], resblock_dilation_sizes: vec![vec![1, 3, 5], vec![1, 3, 5], vec![1, 3, 5]], upsample_initial_channel: 512, use_anti_alias: true, } } } impl BigVGANConfig { /// Calculate total upsampling factor pub fn total_upsample_factor(&self) -> usize { self.upsample_rates.iter().product() } /// Get hop length (same as upsample factor) pub fn hop_length(&self) -> usize { self.total_upsample_factor() } } /// BigVGAN vocoder pub struct BigVGAN { session: Option, config: BigVGANConfig, } impl BigVGAN { /// Load BigVGAN from ONNX model pub fn load>(path: P, config: BigVGANConfig) -> Result { let session = OnnxSession::load(path)?; Ok(Self { session: Some(session), config, }) } /// Create BigVGAN with fallback synthesizer pub fn new_fallback(config: BigVGANConfig) -> Self { Self { session: None, config, } } /// Get configuration pub fn config(&self) -> &BigVGANConfig { &self.config } /// Synthesize audio using fallback algorithm fn synthesize_fallback(&self, mel: &Array2) -> Result> { // Simple overlap-add synthesis as fallback let num_frames = mel.ncols(); let hop_length = self.config.hop_length(); let frame_size = hop_length * 4; // Use 4x overlap let output_length = (num_frames - 1) * hop_length + frame_size; let mut output = vec![0.0f32; output_length]; let mut window_sum = vec![0.0f32; output_length]; // Hann window let window: Vec = (0..frame_size) .map(|n| { 0.5 * (1.0 - (2.0 * std::f32::consts::PI * n as f32 / frame_size as f32).cos()) }) .collect(); // Generate frames from mel for frame_idx in 0..num_frames { let start = frame_idx * hop_length; // Generate frame from mel (simplified: use mel features to modulate noise) let mel_frame: Vec = (0..self.config.num_mels) .map(|i| mel[[i, frame_idx]]) .collect(); // Generate frame using mel features let frame = self.generate_frame(&mel_frame, frame_size); // Overlap-add for i in 0..frame_size { if start + i < output_length { output[start + i] += frame[i] * window[i]; window_sum[start + i] += window[i] * window[i]; } } } // Normalize by window sum for i in 0..output_length { if window_sum[i] > 1e-8 { output[i] /= window_sum[i]; } } // Apply post-processing let output = snake_activation_vec(&output, 0.3); Ok(output) } /// Generate a single frame from mel features fn generate_frame(&self, mel: &[f32], frame_size: usize) -> Vec { use rand::Rng; let mut rng = rand::thread_rng(); // Compute overall energy from mel let energy: f32 = mel.iter().map(|x| x.exp()).sum::() / mel.len() as f32; let energy = energy.sqrt().min(2.0); // Generate frame with harmonic content let mut frame = vec![0.0f32; frame_size]; // Use mel bands to create frequency content for (freq_idx, &mel_val) in mel.iter().enumerate() { let freq = (freq_idx as f32 / mel.len() as f32) * (self.config.sample_rate as f32 / 2.0); let amplitude = mel_val.exp().min(1.0) * 0.1; // Add harmonic for i in 0..frame_size { let t = i as f32 / self.config.sample_rate as f32; frame[i] += amplitude * (2.0 * std::f32::consts::PI * freq * t).sin(); } } // Add filtered noise for i in 0..frame_size { frame[i] += rng.gen_range(-0.1..0.1) * energy * 0.1; } // Normalize let max_abs = frame.iter().map(|x| x.abs()).fold(0.0f32, f32::max); if max_abs > 1.0 { for v in frame.iter_mut() { *v /= max_abs; } } frame } /// Apply post-processing to output pub fn post_process(&self, audio: &[f32]) -> Vec { use crate::audio::{normalize_audio, apply_fade}; let normalized = normalize_audio(audio); // Apply fade to avoid clicks let fade_samples = (self.config.sample_rate as f32 * 0.01) as usize; // 10ms fade apply_fade(&normalized, fade_samples, fade_samples) } } impl Vocoder for BigVGAN { fn synthesize(&self, mel: &Array2) -> Result> { if let Some(ref session) = self.session { // Use ONNX model let input = mel.clone().into_shape(IxDyn(&[1, mel.nrows(), mel.ncols()]))?; let mut inputs = HashMap::new(); inputs.insert("mel".to_string(), input); let outputs = session.run(inputs)?; let audio = outputs .get("audio") .ok_or_else(|| Error::Vocoder("Missing audio output".into()))?; // Extract audio samples let samples: Vec = audio.iter().cloned().collect(); Ok(self.post_process(&samples)) } else { // Use fallback synthesis let audio = self.synthesize_fallback(mel)?; Ok(self.post_process(&audio)) } } fn sample_rate(&self) -> u32 { self.config.sample_rate } fn hop_length(&self) -> usize { self.config.hop_length() } } /// Helper function to create BigVGAN for 22kHz audio pub fn create_bigvgan_22k() -> BigVGAN { let config = BigVGANConfig { sample_rate: 22050, ..Default::default() }; BigVGAN::new_fallback(config) } /// Helper function to create BigVGAN for 24kHz audio pub fn create_bigvgan_24k() -> BigVGAN { let config = BigVGANConfig { sample_rate: 24000, upsample_rates: vec![12, 10, 2, 2], ..Default::default() }; BigVGAN::new_fallback(config) } #[cfg(test)] mod tests { use super::*; #[test] fn test_bigvgan_config() { let config = BigVGANConfig::default(); assert_eq!(config.total_upsample_factor(), 256); assert_eq!(config.hop_length(), 256); } #[test] fn test_bigvgan_fallback() { let vocoder = create_bigvgan_22k(); assert_eq!(vocoder.sample_rate(), 22050); // Create small test mel let mel = Array2::zeros((80, 10)); let result = vocoder.synthesize(&mel); assert!(result.is_ok()); let audio = result.unwrap(); assert!(audio.len() > 0); } #[test] fn test_generate_frame() { let vocoder = create_bigvgan_22k(); let mel = vec![0.0f32; 80]; let frame = vocoder.generate_frame(&mel, 256); assert_eq!(frame.len(), 256); } #[test] fn test_post_process() { let vocoder = create_bigvgan_22k(); let audio = vec![0.5f32; 1000]; let processed = vocoder.post_process(&audio); assert_eq!(processed.len(), audio.len()); // Check fade was applied (first samples should be smaller) assert!(processed[0].abs() < 0.1); } }