diff --git a/src/compute/activations.rs b/src/compute/activations.rs index f25c00c..d2f3710 100644 --- a/src/compute/activations.rs +++ b/src/compute/activations.rs @@ -17,6 +17,13 @@ pub fn drelu(x: &Matrix) -> Matrix { x.map(|v| if v > 0.0 { 1.0 } else { 0.0 }) } +pub fn leaky_relu(x: &Matrix) -> Matrix { + x.map(|v| if v > 0.0 { v } else { 0.01 * v }) +} + +pub fn dleaky_relu(x: &Matrix) -> Matrix { + x.map(|v| if v > 0.0 { 1.0 } else { 0.01 }) +} mod tests { use super::*; @@ -57,4 +64,17 @@ mod tests { let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1); assert_eq!(drelu(&x), expected); } -} \ No newline at end of file + #[test] + fn test_leaky_relu() { + let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1); + let expected = Matrix::from_vec(vec![-0.01, 0.0, 1.0], 3, 1); + assert_eq!(leaky_relu(&x), expected); + } + #[test] + fn test_dleaky_relu() { + let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1); + let expected = Matrix::from_vec(vec![0.01, 0.01, 1.0], 3, 1); + assert_eq!(dleaky_relu(&x), expected); + } + +}