Add tests for mean, variance, and standard deviation calculations in vertical and horizontal directions

This commit is contained in:
Palash Tyagi 2025-07-07 23:36:43 +01:00
parent 6711cad6e2
commit a2fcaf1d52

View File

@ -26,7 +26,8 @@ pub fn variance(x: &Matrix<f64>) -> f64 {
fn _variance_axis(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> { fn _variance_axis(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
match axis { match axis {
Axis::Row => { // Calculate variance for each column (vertical variance) Axis::Row => {
// Calculate variance for each column (vertical variance)
let num_rows = x.rows() as f64; let num_rows = x.rows() as f64;
let mean_of_cols = mean_vertical(x); // 1 x cols matrix let mean_of_cols = mean_vertical(x); // 1 x cols matrix
let mut result_data = vec![0.0; x.cols()]; let mut result_data = vec![0.0; x.cols()];
@ -42,7 +43,8 @@ fn _variance_axis(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
} }
Matrix::from_vec(result_data, 1, x.cols()) Matrix::from_vec(result_data, 1, x.cols())
} }
Axis::Col => { // Calculate variance for each row (horizontal variance) Axis::Col => {
// Calculate variance for each row (horizontal variance)
let num_cols = x.cols() as f64; let num_cols = x.cols() as f64;
let mean_of_rows = mean_horizontal(x); // rows x 1 matrix let mean_of_rows = mean_horizontal(x); // rows x 1 matrix
let mut result_data = vec![0.0; x.rows()]; let mut result_data = vec![0.0; x.rows()];
@ -224,4 +226,58 @@ mod tests {
let x = Matrix::from_vec(data, 1, 3); let x = Matrix::from_vec(data, 1, 3);
percentile(&x, 101.0); percentile(&x, 101.0);
} }
#[test]
fn test_mean_vertical_horizontal() {
// 2x3 matrix:
let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
let x = Matrix::from_vec(data, 2, 3);
// Vertical means (per column): [(1+4)/2, (2+5)/2, (3+6)/2]
let mv = mean_vertical(&x);
assert!((mv.get(0, 0) - 2.5).abs() < EPSILON);
assert!((mv.get(0, 1) - 3.5).abs() < EPSILON);
assert!((mv.get(0, 2) - 4.5).abs() < EPSILON);
// Horizontal means (per row): [(1+2+3)/3, (4+5+6)/3]
let mh = mean_horizontal(&x);
assert!((mh.get(0, 0) - 2.0).abs() < EPSILON);
assert!((mh.get(1, 0) - 5.0).abs() < EPSILON);
}
#[test]
fn test_variance_vertical_horizontal() {
let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
let x = Matrix::from_vec(data, 2, 3);
// Vertical variances (per column): each is ((v - mean)^2 summed / 2)
// cols: {1,4}, {2,5}, {3,6} all give 2.25
let vv = variance_vertical(&x);
for c in 0..3 {
assert!((vv.get(0, c) - 2.25).abs() < EPSILON);
}
// Horizontal variances (per row): rows [1,2,3] and [4,5,6] both give 2/3
let vh = variance_horizontal(&x);
assert!((vh.get(0, 0) - (2.0 / 3.0)).abs() < EPSILON);
assert!((vh.get(1, 0) - (2.0 / 3.0)).abs() < EPSILON);
}
#[test]
fn test_stddev_vertical_horizontal() {
let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
let x = Matrix::from_vec(data, 2, 3);
// Stddev is sqrt of variance
let sv = stddev_vertical(&x);
for c in 0..3 {
assert!((sv.get(0, c) - 1.5).abs() < EPSILON);
}
let sh = stddev_horizontal(&x);
// sqrt(2/3) ≈ 0.816497
let expected = (2.0 / 3.0 as f64).sqrt();
assert!((sh.get(0, 0) - expected).abs() < EPSILON);
assert!((sh.get(1, 0) - expected).abs() < EPSILON);
}
} }