diff --git a/src/compute/stats/descriptive.rs b/src/compute/stats/descriptive.rs index 0be00bb..a9e8245 100644 --- a/src/compute/stats/descriptive.rs +++ b/src/compute/stats/descriptive.rs @@ -26,7 +26,8 @@ pub fn variance(x: &Matrix) -> f64 { fn _variance_axis(x: &Matrix, axis: Axis) -> Matrix { match axis { - Axis::Row => { // Calculate variance for each column (vertical variance) + Axis::Row => { + // Calculate variance for each column (vertical variance) let num_rows = x.rows() as f64; let mean_of_cols = mean_vertical(x); // 1 x cols matrix let mut result_data = vec![0.0; x.cols()]; @@ -42,7 +43,8 @@ fn _variance_axis(x: &Matrix, axis: Axis) -> Matrix { } Matrix::from_vec(result_data, 1, x.cols()) } - Axis::Col => { // Calculate variance for each row (horizontal variance) + Axis::Col => { + // Calculate variance for each row (horizontal variance) let num_cols = x.cols() as f64; let mean_of_rows = mean_horizontal(x); // rows x 1 matrix let mut result_data = vec![0.0; x.rows()]; @@ -224,4 +226,58 @@ mod tests { let x = Matrix::from_vec(data, 1, 3); percentile(&x, 101.0); } + + #[test] + fn test_mean_vertical_horizontal() { + // 2x3 matrix: + let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]; + let x = Matrix::from_vec(data, 2, 3); + + // Vertical means (per column): [(1+4)/2, (2+5)/2, (3+6)/2] + let mv = mean_vertical(&x); + assert!((mv.get(0, 0) - 2.5).abs() < EPSILON); + assert!((mv.get(0, 1) - 3.5).abs() < EPSILON); + assert!((mv.get(0, 2) - 4.5).abs() < EPSILON); + + // Horizontal means (per row): [(1+2+3)/3, (4+5+6)/3] + let mh = mean_horizontal(&x); + assert!((mh.get(0, 0) - 2.0).abs() < EPSILON); + assert!((mh.get(1, 0) - 5.0).abs() < EPSILON); + } + + #[test] + fn test_variance_vertical_horizontal() { + let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]; + let x = Matrix::from_vec(data, 2, 3); + + // Vertical variances (per column): each is ((v - mean)^2 summed / 2) + // cols: {1,4}, {2,5}, {3,6} all give 2.25 + let vv = variance_vertical(&x); + for c in 0..3 { + assert!((vv.get(0, c) - 2.25).abs() < EPSILON); + } + + // Horizontal variances (per row): rows [1,2,3] and [4,5,6] both give 2/3 + let vh = variance_horizontal(&x); + assert!((vh.get(0, 0) - (2.0 / 3.0)).abs() < EPSILON); + assert!((vh.get(1, 0) - (2.0 / 3.0)).abs() < EPSILON); + } + + #[test] + fn test_stddev_vertical_horizontal() { + let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]; + let x = Matrix::from_vec(data, 2, 3); + + // Stddev is sqrt of variance + let sv = stddev_vertical(&x); + for c in 0..3 { + assert!((sv.get(0, c) - 1.5).abs() < EPSILON); + } + + let sh = stddev_horizontal(&x); + // sqrt(2/3) ≈ 0.816497 + let expected = (2.0 / 3.0 as f64).sqrt(); + assert!((sh.get(0, 0) - expected).abs() < EPSILON); + assert!((sh.get(1, 0) - expected).abs() < EPSILON); + } }