|
|
|
|
|
|
|
|
|
|
|
|
|
|
use std::f32::consts::PI; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn snake_activation(x: f32, alpha: f32) -> f32 { |
|
|
let sin_val = (alpha * x).sin(); |
|
|
x + sin_val * sin_val / alpha |
|
|
} |
|
|
|
|
|
|
|
|
pub fn snake_activation_vec(x: &[f32], alpha: f32) -> Vec<f32> { |
|
|
x.iter().map(|&v| snake_activation(v, alpha)).collect() |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn snake_beta_activation(x: f32, alpha: f32, beta: f32) -> f32 { |
|
|
let sin_val = (alpha * x).sin(); |
|
|
x + sin_val * sin_val / beta |
|
|
} |
|
|
|
|
|
|
|
|
pub fn snake_beta_activation_vec(x: &[f32], alpha: f32, beta: f32) -> Vec<f32> { |
|
|
x.iter() |
|
|
.map(|&v| snake_beta_activation(v, alpha, beta)) |
|
|
.collect() |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn anti_aliased_snake(x: &[f32], alpha: f32, upsample_factor: usize) -> Vec<f32> { |
|
|
|
|
|
let upsampled: Vec<f32> = x |
|
|
.iter() |
|
|
.flat_map(|&v| std::iter::repeat_n(v, upsample_factor)) |
|
|
.collect(); |
|
|
|
|
|
|
|
|
let activated: Vec<f32> = upsampled |
|
|
.iter() |
|
|
.map(|&v| snake_activation(v, alpha)) |
|
|
.collect(); |
|
|
|
|
|
|
|
|
activated |
|
|
.chunks(upsample_factor) |
|
|
.map(|chunk| chunk.iter().sum::<f32>() / chunk.len() as f32) |
|
|
.collect() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn leaky_relu(x: f32, negative_slope: f32) -> f32 { |
|
|
if x >= 0.0 { |
|
|
x |
|
|
} else { |
|
|
negative_slope * x |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn leaky_relu_vec(x: &[f32], negative_slope: f32) -> Vec<f32> { |
|
|
x.iter().map(|&v| leaky_relu(v, negative_slope)).collect() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn gelu(x: f32) -> f32 { |
|
|
0.5 * x * (1.0 + ((2.0 / PI).sqrt() * (x + 0.044715 * x * x * x)).tanh()) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn gelu_vec(x: &[f32]) -> Vec<f32> { |
|
|
x.iter().map(|&v| gelu(v)).collect() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn swish(x: f32) -> f32 { |
|
|
x / (1.0 + (-x).exp()) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn swish_vec(x: &[f32]) -> Vec<f32> { |
|
|
x.iter().map(|&v| swish(v)).collect() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn mish(x: f32) -> f32 { |
|
|
x * ((1.0 + x.exp()).ln()).tanh() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn mish_vec(x: &[f32]) -> Vec<f32> { |
|
|
x.iter().map(|&v| mish(v)).collect() |
|
|
} |
|
|
|
|
|
#[cfg(test)] |
|
|
mod tests { |
|
|
use super::*; |
|
|
|
|
|
#[test] |
|
|
fn test_snake_activation() { |
|
|
let result = snake_activation(0.0, 1.0); |
|
|
assert!((result - 0.0).abs() < 1e-6); |
|
|
|
|
|
let result = snake_activation(1.0, 1.0); |
|
|
assert!(result > 1.0); |
|
|
} |
|
|
|
|
|
#[test] |
|
|
fn test_snake_beta_activation() { |
|
|
let result = snake_beta_activation(0.0, 1.0, 1.0); |
|
|
assert!((result - 0.0).abs() < 1e-6); |
|
|
} |
|
|
|
|
|
#[test] |
|
|
fn test_leaky_relu() { |
|
|
assert_eq!(leaky_relu(1.0, 0.01), 1.0); |
|
|
assert_eq!(leaky_relu(-1.0, 0.01), -0.01); |
|
|
assert_eq!(leaky_relu(0.0, 0.01), 0.0); |
|
|
} |
|
|
|
|
|
#[test] |
|
|
fn test_gelu() { |
|
|
let result = gelu(0.0); |
|
|
assert!((result - 0.0).abs() < 1e-6); |
|
|
|
|
|
let result = gelu(1.0); |
|
|
assert!(result > 0.5 && result < 1.0); |
|
|
} |
|
|
|
|
|
#[test] |
|
|
fn test_swish() { |
|
|
let result = swish(0.0); |
|
|
assert!((result - 0.0).abs() < 1e-6); |
|
|
|
|
|
let result = swish(1.0); |
|
|
assert!(result > 0.5 && result < 1.0); |
|
|
} |
|
|
|
|
|
#[test] |
|
|
fn test_anti_aliased_snake() { |
|
|
let input = vec![0.0, 1.0, 2.0, 3.0]; |
|
|
let result = anti_aliased_snake(&input, 1.0, 2); |
|
|
assert_eq!(result.len(), input.len()); |
|
|
} |
|
|
} |
|
|
|