mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Refactor variance functions to distinguish between population and sample variance
This commit is contained in:
parent
64722914bd
commit
285147d52b
@ -14,17 +14,29 @@ pub fn mean_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
|
|||||||
Matrix::from_vec(x.sum_horizontal(), x.rows(), 1) / n
|
Matrix::from_vec(x.sum_horizontal(), x.rows(), 1) / n
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn variance(x: &Matrix<f64>) -> f64 {
|
fn population_or_sample_variance(x: &Matrix<f64>, population: bool) -> f64 {
|
||||||
let m = (x.rows() * x.cols()) as f64;
|
let m = (x.rows() * x.cols()) as f64;
|
||||||
let mean_val = mean(x);
|
let mean_val = mean(x);
|
||||||
x.data()
|
x.data()
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&v| (v - mean_val).powi(2))
|
.map(|&v| (v - mean_val).powi(2))
|
||||||
.sum::<f64>()
|
.sum::<f64>()
|
||||||
/ m
|
/ if population { m } else { m - 1.0 }
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _variance_axis(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
|
pub fn population_variance(x: &Matrix<f64>) -> f64 {
|
||||||
|
population_or_sample_variance(x, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sample_variance(x: &Matrix<f64>) -> f64 {
|
||||||
|
population_or_sample_variance(x, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _population_or_sample_variance_axis(
|
||||||
|
x: &Matrix<f64>,
|
||||||
|
axis: Axis,
|
||||||
|
population: bool,
|
||||||
|
) -> Matrix<f64> {
|
||||||
match axis {
|
match axis {
|
||||||
Axis::Row => {
|
Axis::Row => {
|
||||||
// Calculate variance for each column (vertical variance)
|
// Calculate variance for each column (vertical variance)
|
||||||
@ -39,7 +51,7 @@ fn _variance_axis(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
|
|||||||
let diff = x.get(r, c) - mean_val;
|
let diff = x.get(r, c) - mean_val;
|
||||||
sum_sq_diff += diff * diff;
|
sum_sq_diff += diff * diff;
|
||||||
}
|
}
|
||||||
result_data[c] = sum_sq_diff / num_rows;
|
result_data[c] = sum_sq_diff / (if population { num_rows } else { num_rows - 1.0 });
|
||||||
}
|
}
|
||||||
Matrix::from_vec(result_data, 1, x.cols())
|
Matrix::from_vec(result_data, 1, x.cols())
|
||||||
}
|
}
|
||||||
@ -56,30 +68,39 @@ fn _variance_axis(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
|
|||||||
let diff = x.get(r, c) - mean_val;
|
let diff = x.get(r, c) - mean_val;
|
||||||
sum_sq_diff += diff * diff;
|
sum_sq_diff += diff * diff;
|
||||||
}
|
}
|
||||||
result_data[r] = sum_sq_diff / num_cols;
|
result_data[r] = sum_sq_diff / (if population { num_cols } else { num_cols - 1.0 });
|
||||||
}
|
}
|
||||||
Matrix::from_vec(result_data, x.rows(), 1)
|
Matrix::from_vec(result_data, x.rows(), 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn variance_vertical(x: &Matrix<f64>) -> Matrix<f64> {
|
pub fn population_variance_vertical(x: &Matrix<f64>) -> Matrix<f64> {
|
||||||
_variance_axis(x, Axis::Row)
|
_population_or_sample_variance_axis(x, Axis::Row, true)
|
||||||
}
|
}
|
||||||
pub fn variance_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
|
|
||||||
_variance_axis(x, Axis::Col)
|
pub fn population_variance_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
|
||||||
|
_population_or_sample_variance_axis(x, Axis::Col, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sample_variance_vertical(x: &Matrix<f64>) -> Matrix<f64> {
|
||||||
|
_population_or_sample_variance_axis(x, Axis::Row, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sample_variance_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
|
||||||
|
_population_or_sample_variance_axis(x, Axis::Col, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stddev(x: &Matrix<f64>) -> f64 {
|
pub fn stddev(x: &Matrix<f64>) -> f64 {
|
||||||
variance(x).sqrt()
|
population_variance(x).sqrt()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stddev_vertical(x: &Matrix<f64>) -> Matrix<f64> {
|
pub fn stddev_vertical(x: &Matrix<f64>) -> Matrix<f64> {
|
||||||
variance_vertical(x).map(|v| v.sqrt())
|
population_variance_vertical(x).map(|v| v.sqrt())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stddev_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
|
pub fn stddev_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
|
||||||
variance_horizontal(x).map(|v| v.sqrt())
|
population_variance_horizontal(x).map(|v| v.sqrt())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn median(x: &Matrix<f64>) -> f64 {
|
pub fn median(x: &Matrix<f64>) -> f64 {
|
||||||
@ -180,7 +201,7 @@ mod tests {
|
|||||||
assert!((mean(&x) - 3.0).abs() < EPSILON);
|
assert!((mean(&x) - 3.0).abs() < EPSILON);
|
||||||
|
|
||||||
// Variance
|
// Variance
|
||||||
assert!((variance(&x) - 2.0).abs() < EPSILON);
|
assert!((population_variance(&x) - 2.0).abs() < EPSILON);
|
||||||
|
|
||||||
// Standard Deviation
|
// Standard Deviation
|
||||||
assert!((stddev(&x) - 1.4142135623730951).abs() < EPSILON);
|
assert!((stddev(&x) - 1.4142135623730951).abs() < EPSILON);
|
||||||
@ -209,7 +230,7 @@ mod tests {
|
|||||||
assert!((mean(&x) - 22.0).abs() < EPSILON);
|
assert!((mean(&x) - 22.0).abs() < EPSILON);
|
||||||
|
|
||||||
// Variance should be heavily affected by outlier
|
// Variance should be heavily affected by outlier
|
||||||
assert!((variance(&x) - 1522.0).abs() < EPSILON);
|
assert!((population_variance(&x) - 1522.0).abs() < EPSILON);
|
||||||
|
|
||||||
// Standard Deviation should be heavily affected by outlier
|
// Standard Deviation should be heavily affected by outlier
|
||||||
assert!((stddev(&x) - 39.0128183970461).abs() < EPSILON);
|
assert!((stddev(&x) - 39.0128183970461).abs() < EPSILON);
|
||||||
@ -258,12 +279,12 @@ mod tests {
|
|||||||
let x = Matrix::from_vec(data, 2, 3);
|
let x = Matrix::from_vec(data, 2, 3);
|
||||||
|
|
||||||
// cols: {1,4}, {2,5}, {3,6} all give 2.25
|
// cols: {1,4}, {2,5}, {3,6} all give 2.25
|
||||||
let vv = variance_vertical(&x);
|
let vv = population_variance_vertical(&x);
|
||||||
for c in 0..3 {
|
for c in 0..3 {
|
||||||
assert!((vv.get(0, c) - 2.25).abs() < EPSILON);
|
assert!((vv.get(0, c) - 2.25).abs() < EPSILON);
|
||||||
}
|
}
|
||||||
|
|
||||||
let vh = variance_horizontal(&x);
|
let vh = population_variance_horizontal(&x);
|
||||||
assert!((vh.get(0, 0) - (2.0 / 3.0)).abs() < EPSILON);
|
assert!((vh.get(0, 0) - (2.0 / 3.0)).abs() < EPSILON);
|
||||||
assert!((vh.get(1, 0) - (2.0 / 3.0)).abs() < EPSILON);
|
assert!((vh.get(1, 0) - (2.0 / 3.0)).abs() < EPSILON);
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user