File size: 3,785 Bytes
2bbfbb7 0393dfa 2bbfbb7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
//! Activation functions for BigVGAN
//!
//! Includes Snake and SnakeBeta activations
use std::f32::consts::PI;
/// Snake activation function
///
/// x + (1/alpha) * sin^2(alpha * x)
pub fn snake_activation(x: f32, alpha: f32) -> f32 {
let sin_val = (alpha * x).sin();
x + sin_val * sin_val / alpha
}
/// Snake activation for vector
pub fn snake_activation_vec(x: &[f32], alpha: f32) -> Vec<f32> {
x.iter().map(|&v| snake_activation(v, alpha)).collect()
}
/// Snake Beta activation function
///
/// x + (1/beta) * sin^2(alpha * x)
pub fn snake_beta_activation(x: f32, alpha: f32, beta: f32) -> f32 {
let sin_val = (alpha * x).sin();
x + sin_val * sin_val / beta
}
/// Snake Beta activation for vector
pub fn snake_beta_activation_vec(x: &[f32], alpha: f32, beta: f32) -> Vec<f32> {
x.iter()
.map(|&v| snake_beta_activation(v, alpha, beta))
.collect()
}
/// Anti-aliased Snake activation
///
/// Uses lowpass filtering to reduce aliasing artifacts
pub fn anti_aliased_snake(x: &[f32], alpha: f32, upsample_factor: usize) -> Vec<f32> {
// Upsample
let upsampled: Vec<f32> = x
.iter()
.flat_map(|&v| std::iter::repeat_n(v, upsample_factor))
.collect();
// Apply activation
let activated: Vec<f32> = upsampled
.iter()
.map(|&v| snake_activation(v, alpha))
.collect();
// Downsample (simple averaging)
activated
.chunks(upsample_factor)
.map(|chunk| chunk.iter().sum::<f32>() / chunk.len() as f32)
.collect()
}
/// Leaky ReLU activation
pub fn leaky_relu(x: f32, negative_slope: f32) -> f32 {
if x >= 0.0 {
x
} else {
negative_slope * x
}
}
/// Leaky ReLU for vector
pub fn leaky_relu_vec(x: &[f32], negative_slope: f32) -> Vec<f32> {
x.iter().map(|&v| leaky_relu(v, negative_slope)).collect()
}
/// GELU (Gaussian Error Linear Unit) activation
pub fn gelu(x: f32) -> f32 {
0.5 * x * (1.0 + ((2.0 / PI).sqrt() * (x + 0.044715 * x * x * x)).tanh())
}
/// GELU for vector
pub fn gelu_vec(x: &[f32]) -> Vec<f32> {
x.iter().map(|&v| gelu(v)).collect()
}
/// Swish activation (SiLU)
pub fn swish(x: f32) -> f32 {
x / (1.0 + (-x).exp())
}
/// Swish for vector
pub fn swish_vec(x: &[f32]) -> Vec<f32> {
x.iter().map(|&v| swish(v)).collect()
}
/// Mish activation
pub fn mish(x: f32) -> f32 {
x * ((1.0 + x.exp()).ln()).tanh()
}
/// Mish for vector
pub fn mish_vec(x: &[f32]) -> Vec<f32> {
x.iter().map(|&v| mish(v)).collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_snake_activation() {
let result = snake_activation(0.0, 1.0);
assert!((result - 0.0).abs() < 1e-6);
let result = snake_activation(1.0, 1.0);
assert!(result > 1.0); // Should add positive value
}
#[test]
fn test_snake_beta_activation() {
let result = snake_beta_activation(0.0, 1.0, 1.0);
assert!((result - 0.0).abs() < 1e-6);
}
#[test]
fn test_leaky_relu() {
assert_eq!(leaky_relu(1.0, 0.01), 1.0);
assert_eq!(leaky_relu(-1.0, 0.01), -0.01);
assert_eq!(leaky_relu(0.0, 0.01), 0.0);
}
#[test]
fn test_gelu() {
let result = gelu(0.0);
assert!((result - 0.0).abs() < 1e-6);
let result = gelu(1.0);
assert!(result > 0.5 && result < 1.0);
}
#[test]
fn test_swish() {
let result = swish(0.0);
assert!((result - 0.0).abs() < 1e-6);
let result = swish(1.0);
assert!(result > 0.5 && result < 1.0);
}
#[test]
fn test_anti_aliased_snake() {
let input = vec![0.0, 1.0, 2.0, 3.0];
let result = anti_aliased_snake(&input, 1.0, 2);
assert_eq!(result.len(), input.len());
}
}
|