Add unit tests for activation functions: sigmoid, relu, dsigmoid, and drelu

This commit is contained in:
Palash Tyagi 2025-07-06 17:51:43 +01:00
parent b279131503
commit 37b20f2174

View File

@ -16,3 +16,45 @@ pub fn relu(x: &Matrix<f64>) -> Matrix<f64> {
pub fn drelu(x: &Matrix<f64>) -> Matrix<f64> { pub fn drelu(x: &Matrix<f64>) -> Matrix<f64> {
x.map(|v| if v > 0.0 { 1.0 } else { 0.0 }) x.map(|v| if v > 0.0 { 1.0 } else { 0.0 })
} }
mod tests {
use super::*;
// Helper function to round all elements in a matrix to n decimal places
fn _round_matrix(mat: &Matrix<f64>, decimals: u32) -> Matrix<f64> {
let factor = 10f64.powi(decimals as i32);
let rounded: Vec<f64> = mat.to_vec().iter().map(|v| (v * factor).round() / factor).collect();
Matrix::from_vec(rounded, mat.rows(), mat.cols())
}
#[test]
fn test_sigmoid() {
let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
let expected = Matrix::from_vec(vec![0.26894142, 0.5, 0.73105858], 3, 1);
let result = sigmoid(&x);
assert_eq!(_round_matrix(&result, 6), _round_matrix(&expected, 6));
}
#[test]
fn test_relu() {
let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
assert_eq!(relu(&x), expected);
}
#[test]
fn test_dsigmoid() {
let y = Matrix::from_vec(vec![0.26894142, 0.5, 0.73105858], 3, 1);
let expected = Matrix::from_vec(vec![0.19661193, 0.25, 0.19661193], 3, 1);
let result = dsigmoid(&y);
assert_eq!(_round_matrix(&result, 6), _round_matrix(&expected, 6));
}
#[test]
fn test_drelu() {
let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
assert_eq!(drelu(&x), expected);
}
}