Add leaky_relu and dleaky_relu functions with corresponding unit tests

This commit is contained in:
Palash Tyagi 2025-07-06 20:00:17 +01:00
parent ab6d5f9f8f
commit 4c626bf09c

View File

@ -17,6 +17,13 @@ pub fn drelu(x: &Matrix<f64>) -> Matrix<f64> {
x.map(|v| if v > 0.0 { 1.0 } else { 0.0 }) x.map(|v| if v > 0.0 { 1.0 } else { 0.0 })
} }
pub fn leaky_relu(x: &Matrix<f64>) -> Matrix<f64> {
x.map(|v| if v > 0.0 { v } else { 0.01 * v })
}
pub fn dleaky_relu(x: &Matrix<f64>) -> Matrix<f64> {
x.map(|v| if v > 0.0 { 1.0 } else { 0.01 })
}
mod tests { mod tests {
use super::*; use super::*;
@ -57,4 +64,17 @@ mod tests {
let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1); let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
assert_eq!(drelu(&x), expected); assert_eq!(drelu(&x), expected);
} }
} #[test]
fn test_leaky_relu() {
let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
let expected = Matrix::from_vec(vec![-0.01, 0.0, 1.0], 3, 1);
assert_eq!(leaky_relu(&x), expected);
}
#[test]
fn test_dleaky_relu() {
let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
let expected = Matrix::from_vec(vec![0.01, 0.01, 1.0], 3, 1);
assert_eq!(dleaky_relu(&x), expected);
}
}