|
|
|
|
|
|
|
|
|
|
|
|
|
|
use crate::{Error, Result}; |
|
|
use ndarray::{Array2, IxDyn}; |
|
|
use std::collections::HashMap; |
|
|
use std::path::Path; |
|
|
|
|
|
use crate::model::OnnxSession; |
|
|
use super::{Vocoder, snake_activation_vec}; |
|
|
|
|
|
|
|
|
/// Hyper-parameters describing a BigVGAN vocoder.
///
/// Only `sample_rate`, `num_mels` and `upsample_rates` are read by the code
/// in this file; the remaining fields are carried along unchanged and
/// presumably mirror the architecture of the exported network (confirm
/// against the model export).
#[derive(Debug, Clone)]
pub struct BigVGANConfig {
    /// Sample rate of the generated audio (22050 by default, presumably Hz).
    pub sample_rate: u32,

    /// Number of mel bands; the fallback synthesizer reads this many rows
    /// from every spectrogram frame.
    pub num_mels: usize,

    /// Per-stage upsampling factors. Their product is the hop length.
    pub upsample_rates: Vec<usize>,

    /// Kernel size for each upsampling stage (not read in this file).
    pub upsample_kernel_sizes: Vec<usize>,

    /// Kernel sizes of the residual blocks (not read in this file).
    pub resblock_kernel_sizes: Vec<usize>,

    /// Dilation sizes, one list per residual block (not read in this file).
    pub resblock_dilation_sizes: Vec<Vec<usize>>,

    /// Channel count before the first upsampling stage (not read in this file).
    pub upsample_initial_channel: usize,

    /// Anti-aliasing switch (not read in this file).
    pub use_anti_alias: bool,
}
|
|
|
|
|
impl Default for BigVGANConfig { |
|
|
fn default() -> Self { |
|
|
Self { |
|
|
sample_rate: 22050, |
|
|
num_mels: 80, |
|
|
upsample_rates: vec![8, 8, 2, 2], |
|
|
upsample_kernel_sizes: vec![16, 16, 4, 4], |
|
|
resblock_kernel_sizes: vec![3, 7, 11], |
|
|
resblock_dilation_sizes: vec![vec![1, 3, 5], vec![1, 3, 5], vec![1, 3, 5]], |
|
|
upsample_initial_channel: 512, |
|
|
use_anti_alias: true, |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
impl BigVGANConfig { |
|
|
|
|
|
pub fn total_upsample_factor(&self) -> usize { |
|
|
self.upsample_rates.iter().product() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn hop_length(&self) -> usize { |
|
|
self.total_upsample_factor() |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub struct BigVGAN { |
|
|
session: Option<OnnxSession>, |
|
|
config: BigVGANConfig, |
|
|
} |
|
|
|
|
|
impl BigVGAN { |
|
|
|
|
|
pub fn load<P: AsRef<Path>>(path: P, config: BigVGANConfig) -> Result<Self> { |
|
|
let session = OnnxSession::load(path)?; |
|
|
Ok(Self { |
|
|
session: Some(session), |
|
|
config, |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn new_fallback(config: BigVGANConfig) -> Self { |
|
|
Self { |
|
|
session: None, |
|
|
config, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn config(&self) -> &BigVGANConfig { |
|
|
&self.config |
|
|
} |
|
|
|
|
|
|
|
|
fn synthesize_fallback(&self, mel: &Array2<f32>) -> Result<Vec<f32>> { |
|
|
|
|
|
let num_frames = mel.ncols(); |
|
|
let hop_length = self.config.hop_length(); |
|
|
let frame_size = hop_length * 4; |
|
|
|
|
|
let output_length = (num_frames - 1) * hop_length + frame_size; |
|
|
let mut output = vec![0.0f32; output_length]; |
|
|
let mut window_sum = vec![0.0f32; output_length]; |
|
|
|
|
|
|
|
|
let window: Vec<f32> = (0..frame_size) |
|
|
.map(|n| { |
|
|
0.5 * (1.0 - (2.0 * std::f32::consts::PI * n as f32 / frame_size as f32).cos()) |
|
|
}) |
|
|
.collect(); |
|
|
|
|
|
|
|
|
for frame_idx in 0..num_frames { |
|
|
let start = frame_idx * hop_length; |
|
|
|
|
|
|
|
|
let mel_frame: Vec<f32> = (0..self.config.num_mels) |
|
|
.map(|i| mel[[i, frame_idx]]) |
|
|
.collect(); |
|
|
|
|
|
|
|
|
let frame = self.generate_frame(&mel_frame, frame_size); |
|
|
|
|
|
|
|
|
for i in 0..frame_size { |
|
|
if start + i < output_length { |
|
|
output[start + i] += frame[i] * window[i]; |
|
|
window_sum[start + i] += window[i] * window[i]; |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for i in 0..output_length { |
|
|
if window_sum[i] > 1e-8 { |
|
|
output[i] /= window_sum[i]; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
let output = snake_activation_vec(&output, 0.3); |
|
|
|
|
|
Ok(output) |
|
|
} |
|
|
|
|
|
|
|
|
fn generate_frame(&self, mel: &[f32], frame_size: usize) -> Vec<f32> { |
|
|
use rand::Rng; |
|
|
let mut rng = rand::thread_rng(); |
|
|
|
|
|
|
|
|
let energy: f32 = mel.iter().map(|x| x.exp()).sum::<f32>() / mel.len() as f32; |
|
|
let energy = energy.sqrt().min(2.0); |
|
|
|
|
|
|
|
|
let mut frame = vec![0.0f32; frame_size]; |
|
|
|
|
|
|
|
|
for (freq_idx, &mel_val) in mel.iter().enumerate() { |
|
|
let freq = (freq_idx as f32 / mel.len() as f32) * (self.config.sample_rate as f32 / 2.0); |
|
|
let amplitude = mel_val.exp().min(1.0) * 0.1; |
|
|
|
|
|
|
|
|
for i in 0..frame_size { |
|
|
let t = i as f32 / self.config.sample_rate as f32; |
|
|
frame[i] += amplitude * (2.0 * std::f32::consts::PI * freq * t).sin(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for i in 0..frame_size { |
|
|
frame[i] += rng.gen_range(-0.1..0.1) * energy * 0.1; |
|
|
} |
|
|
|
|
|
|
|
|
let max_abs = frame.iter().map(|x| x.abs()).fold(0.0f32, f32::max); |
|
|
if max_abs > 1.0 { |
|
|
for v in frame.iter_mut() { |
|
|
*v /= max_abs; |
|
|
} |
|
|
} |
|
|
|
|
|
frame |
|
|
} |
|
|
|
|
|
|
|
|
pub fn post_process(&self, audio: &[f32]) -> Vec<f32> { |
|
|
use crate::audio::{normalize_audio, apply_fade}; |
|
|
|
|
|
let normalized = normalize_audio(audio); |
|
|
|
|
|
|
|
|
let fade_samples = (self.config.sample_rate as f32 * 0.01) as usize; |
|
|
apply_fade(&normalized, fade_samples, fade_samples) |
|
|
} |
|
|
} |
|
|
|
|
|
impl Vocoder for BigVGAN { |
|
|
fn synthesize(&self, mel: &Array2<f32>) -> Result<Vec<f32>> { |
|
|
if let Some(ref session) = self.session { |
|
|
|
|
|
let input = mel.clone().into_shape(IxDyn(&[1, mel.nrows(), mel.ncols()]))?; |
|
|
|
|
|
let mut inputs = HashMap::new(); |
|
|
inputs.insert("mel".to_string(), input); |
|
|
|
|
|
let outputs = session.run(inputs)?; |
|
|
|
|
|
let audio = outputs |
|
|
.get("audio") |
|
|
.ok_or_else(|| Error::Vocoder("Missing audio output".into()))?; |
|
|
|
|
|
|
|
|
let samples: Vec<f32> = audio.iter().cloned().collect(); |
|
|
|
|
|
Ok(self.post_process(&samples)) |
|
|
} else { |
|
|
|
|
|
let audio = self.synthesize_fallback(mel)?; |
|
|
Ok(self.post_process(&audio)) |
|
|
} |
|
|
} |
|
|
|
|
|
fn sample_rate(&self) -> u32 { |
|
|
self.config.sample_rate |
|
|
} |
|
|
|
|
|
fn hop_length(&self) -> usize { |
|
|
self.config.hop_length() |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn create_bigvgan_22k() -> BigVGAN { |
|
|
let config = BigVGANConfig { |
|
|
sample_rate: 22050, |
|
|
..Default::default() |
|
|
}; |
|
|
BigVGAN::new_fallback(config) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn create_bigvgan_24k() -> BigVGAN { |
|
|
let config = BigVGANConfig { |
|
|
sample_rate: 24000, |
|
|
upsample_rates: vec![12, 10, 2, 2], |
|
|
..Default::default() |
|
|
}; |
|
|
BigVGAN::new_fallback(config) |
|
|
} |
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// The default upsample rates multiply to a hop length of 256.
    #[test]
    fn test_bigvgan_config() {
        let defaults = BigVGANConfig::default();
        assert_eq!(defaults.total_upsample_factor(), 256);
        assert_eq!(defaults.hop_length(), 256);
    }

    /// The fallback vocoder turns an all-zero mel into non-empty audio.
    #[test]
    fn test_bigvgan_fallback() {
        let vocoder = create_bigvgan_22k();
        assert_eq!(vocoder.sample_rate(), 22050);

        // 80 mel bands by 10 frames.
        let mel = Array2::<f32>::zeros((80, 10));
        let outcome = vocoder.synthesize(&mel);
        assert!(outcome.is_ok());

        let samples = outcome.unwrap();
        assert!(!samples.is_empty());
    }

    /// A generated frame has exactly the requested number of samples.
    #[test]
    fn test_generate_frame() {
        let vocoder = create_bigvgan_22k();
        let mel_column = vec![0.0f32; 80];
        let frame = vocoder.generate_frame(&mel_column, 256);
        assert_eq!(256, frame.len());
    }

    /// Post-processing keeps the length and leaves the first sample quiet.
    #[test]
    fn test_post_process() {
        let vocoder = create_bigvgan_22k();
        let raw = vec![0.5f32; 1000];
        let processed = vocoder.post_process(&raw);
        assert_eq!(processed.len(), raw.len());

        assert!(processed[0].abs() < 0.1);
    }
}
|
|
|