From 285147d52b9fab888bf8f0c09d7cbee40f3e8e27 Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Tue, 15 Jul 2025 01:00:03 +0100 Subject: [PATCH] Refactor variance functions to distinguish between population and sample variance --- src/compute/stats/descriptive.rs | 53 ++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/src/compute/stats/descriptive.rs b/src/compute/stats/descriptive.rs index 7dedee8..d52ab62 100644 --- a/src/compute/stats/descriptive.rs +++ b/src/compute/stats/descriptive.rs @@ -14,17 +14,29 @@ pub fn mean_horizontal(x: &Matrix) -> Matrix { Matrix::from_vec(x.sum_horizontal(), x.rows(), 1) / n } -pub fn variance(x: &Matrix) -> f64 { +fn population_or_sample_variance(x: &Matrix, population: bool) -> f64 { let m = (x.rows() * x.cols()) as f64; let mean_val = mean(x); x.data() .iter() .map(|&v| (v - mean_val).powi(2)) .sum::() - / m + / if population { m } else { m - 1.0 } } -fn _variance_axis(x: &Matrix, axis: Axis) -> Matrix { +pub fn population_variance(x: &Matrix) -> f64 { + population_or_sample_variance(x, true) +} + +pub fn sample_variance(x: &Matrix) -> f64 { + population_or_sample_variance(x, false) +} + +fn _population_or_sample_variance_axis( + x: &Matrix, + axis: Axis, + population: bool, +) -> Matrix { match axis { Axis::Row => { // Calculate variance for each column (vertical variance) @@ -39,7 +51,7 @@ fn _variance_axis(x: &Matrix, axis: Axis) -> Matrix { let diff = x.get(r, c) - mean_val; sum_sq_diff += diff * diff; } - result_data[c] = sum_sq_diff / num_rows; + result_data[c] = sum_sq_diff / (if population { num_rows } else { num_rows - 1.0 }); } Matrix::from_vec(result_data, 1, x.cols()) } @@ -56,30 +68,39 @@ fn _variance_axis(x: &Matrix, axis: Axis) -> Matrix { let diff = x.get(r, c) - mean_val; sum_sq_diff += diff * diff; } - result_data[r] = sum_sq_diff / num_cols; + result_data[r] = sum_sq_diff / (if population { num_cols } else { num_cols - 1.0 }); } Matrix::from_vec(result_data, x.rows(), 1) } } } -pub fn variance_vertical(x: &Matrix) -> Matrix { - _variance_axis(x, Axis::Row) +pub fn population_variance_vertical(x: &Matrix) -> Matrix { + _population_or_sample_variance_axis(x, Axis::Row, true) } -pub fn variance_horizontal(x: &Matrix) -> Matrix { - _variance_axis(x, Axis::Col) + +pub fn population_variance_horizontal(x: &Matrix) -> Matrix { + _population_or_sample_variance_axis(x, Axis::Col, true) +} + +pub fn sample_variance_vertical(x: &Matrix) -> Matrix { + _population_or_sample_variance_axis(x, Axis::Row, false) +} + +pub fn sample_variance_horizontal(x: &Matrix) -> Matrix { + _population_or_sample_variance_axis(x, Axis::Col, false) } pub fn stddev(x: &Matrix) -> f64 { - variance(x).sqrt() + population_variance(x).sqrt() } pub fn stddev_vertical(x: &Matrix) -> Matrix { - variance_vertical(x).map(|v| v.sqrt()) + population_variance_vertical(x).map(|v| v.sqrt()) } pub fn stddev_horizontal(x: &Matrix) -> Matrix { - variance_horizontal(x).map(|v| v.sqrt()) + population_variance_horizontal(x).map(|v| v.sqrt()) } pub fn median(x: &Matrix) -> f64 { @@ -180,7 +201,7 @@ mod tests { assert!((mean(&x) - 3.0).abs() < EPSILON); // Variance - assert!((variance(&x) - 2.0).abs() < EPSILON); + assert!((population_variance(&x) - 2.0).abs() < EPSILON); // Standard Deviation assert!((stddev(&x) - 1.4142135623730951).abs() < EPSILON); @@ -209,7 +230,7 @@ mod tests { assert!((mean(&x) - 22.0).abs() < EPSILON); // Variance should be heavily affected by outlier - assert!((variance(&x) - 1522.0).abs() < EPSILON); + assert!((population_variance(&x) - 1522.0).abs() < EPSILON); // Standard Deviation should be heavily affected by outlier assert!((stddev(&x) - 39.0128183970461).abs() < EPSILON); @@ -258,12 +279,12 @@ mod tests { let x = Matrix::from_vec(data, 2, 3); // cols: {1,4}, {2,5}, {3,6} all give 2.25 - let vv = variance_vertical(&x); + let vv = population_variance_vertical(&x); for c in 0..3 { assert!((vv.get(0, c) - 2.25).abs() < EPSILON); } - let vh = variance_horizontal(&x); + let vh = population_variance_horizontal(&x); assert!((vh.get(0, 0) - (2.0 / 3.0)).abs() < EPSILON); assert!((vh.get(1, 0) - (2.0 / 3.0)).abs() < EPSILON); }