diff --git a/src/compute/stats/descriptive.rs b/src/compute/stats/descriptive.rs index a9e8245..276e531 100644 --- a/src/compute/stats/descriptive.rs +++ b/src/compute/stats/descriptive.rs @@ -94,33 +94,35 @@ pub fn median(x: &Matrix) -> f64 { } fn _median_axis(x: &Matrix, axis: Axis) -> Matrix { - let mut data = match axis { - Axis::Row => x.sum_vertical(), - Axis::Col => x.sum_horizontal(), + let mx = match axis { + Axis::Col => x.clone(), + Axis::Row => x.transpose(), }; - data.sort_by(|a, b| a.partial_cmp(b).unwrap()); - let mid = data.len() / 2; - if data.len() % 2 == 0 { - Matrix::from_vec( - vec![(data[mid - 1] + data[mid]) / 2.0], - if axis == Axis::Row { 1 } else { x.rows() }, - if axis == Axis::Row { x.cols() } else { 1 }, - ) - } else { - Matrix::from_vec( - vec![data[mid]], - if axis == Axis::Row { 1 } else { x.rows() }, - if axis == Axis::Row { x.cols() } else { 1 }, - ) + + let mut result = Vec::with_capacity(mx.cols()); + for c in 0..mx.cols() { + let mut col = mx.column(c).to_vec(); + col.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let mid = col.len() / 2; + if col.len() % 2 == 0 { + result.push((col[mid - 1] + col[mid]) / 2.0); + } else { + result.push(col[mid]); + } } + let (r, c) = match axis { + Axis::Col => (1, mx.cols()), + Axis::Row => (mx.cols(), 1), + }; + Matrix::from_vec(result, r, c) } pub fn median_vertical(x: &Matrix) -> Matrix { - _median_axis(x, Axis::Row) + _median_axis(x, Axis::Col) } pub fn median_horizontal(x: &Matrix) -> Matrix { - _median_axis(x, Axis::Col) + _median_axis(x, Axis::Row) } pub fn percentile(x: &Matrix, p: f64) -> f64 { @@ -137,24 +139,29 @@ fn _percentile_axis(x: &Matrix, p: f64, axis: Axis) -> Matrix { if p < 0.0 || p > 100.0 { panic!("Percentile must be between 0 and 100"); } - let mut data = match axis { - Axis::Row => x.sum_vertical(), - Axis::Col => x.sum_horizontal(), + let mx: Matrix = match axis { + Axis::Col => x.clone(), + Axis::Row => x.transpose(), }; - data.sort_by(|a, b| a.partial_cmp(b).unwrap()); - let index = ((p / 100.0) * (data.len() as f64 - 1.0)).round() as usize; - Matrix::from_vec( - vec![data[index]], - if axis == Axis::Row { 1 } else { x.rows() }, - if axis == Axis::Row { x.cols() } else { 1 }, - ) + let mut result = Vec::with_capacity(mx.cols()); + for c in 0..mx.cols() { + let mut col = mx.column(c).to_vec(); + col.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let index = ((p / 100.0) * (col.len() as f64 - 1.0)).round() as usize; + result.push(col[index]); + } + let (r, c) = match axis { + Axis::Col => (1, mx.cols()), + Axis::Row => (mx.cols(), 1), + }; + Matrix::from_vec(result, r, c) } pub fn percentile_vertical(x: &Matrix, p: f64) -> Matrix { - _percentile_axis(x, p, Axis::Row) + _percentile_axis(x, p, Axis::Col) } pub fn percentile_horizontal(x: &Matrix, p: f64) -> Matrix { - _percentile_axis(x, p, Axis::Col) + _percentile_axis(x, p, Axis::Row) } #[cfg(test)] @@ -250,14 +257,12 @@ mod tests { let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]; let x = Matrix::from_vec(data, 2, 3); - // Vertical variances (per column): each is ((v - mean)^2 summed / 2) // cols: {1,4}, {2,5}, {3,6} all give 2.25 let vv = variance_vertical(&x); for c in 0..3 { assert!((vv.get(0, c) - 2.25).abs() < EPSILON); } - // Horizontal variances (per row): rows [1,2,3] and [4,5,6] both give 2/3 let vh = variance_horizontal(&x); assert!((vh.get(0, 0) - (2.0 / 3.0)).abs() < EPSILON); assert!((vh.get(1, 0) - (2.0 / 3.0)).abs() < EPSILON); @@ -280,4 +285,59 @@ mod tests { assert!((sh.get(0, 0) - expected).abs() < EPSILON); assert!((sh.get(1, 0) - expected).abs() < EPSILON); } + + #[test] + fn test_median_vertical_horizontal() { + let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]; + let x = Matrix::from_vec(data, 2, 3); + + let mv = median_vertical(&x).row(0); + + let expected_v = vec![2.5, 3.5, 4.5]; + assert_eq!(mv, expected_v, "{:?} expected: {:?}", expected_v, mv); + } + + #[test] + fn test_percentile_vertical_horizontal() { + // vec of f64 values 1..24 as a 4x6 matrix + let data: Vec = (1..=24).map(|x| x as f64).collect(); + let x = Matrix::from_vec(data, 4, 6); + + // columns: + // 1, 5, 9, 13, 17, 21 + // 2, 6, 10, 14, 18, 22 + // 3, 7, 11, 15, 19, 23 + // 4, 8, 12, 16, 20, 24 + + let er0 = vec![1., 5., 9., 13., 17., 21.]; + let er50 = vec![3., 7., 11., 15., 19., 23.]; + let er100 = vec![4., 8., 12., 16., 20., 24.]; + + assert_eq!(percentile_vertical(&x, 0.0).data(), er0); + assert_eq!(percentile_vertical(&x, 50.0).data(), er50); + assert_eq!(percentile_vertical(&x, 100.0).data(), er100); + + let eh0 = vec![1., 2., 3., 4.]; + let eh50 = vec![13., 14., 15., 16.]; + let eh100 = vec![21., 22., 23., 24.]; + + assert_eq!(percentile_horizontal(&x, 0.0).data(), eh0); + assert_eq!(percentile_horizontal(&x, 50.0).data(), eh50); + assert_eq!(percentile_horizontal(&x, 100.0).data(), eh100); + } + + #[test] + #[should_panic(expected = "Percentile must be between 0 and 100")] + fn test_percentile_out_of_bounds() { + let data = vec![1.0, 2.0, 3.0]; + let x = Matrix::from_vec(data, 1, 3); + percentile(&x, -10.0); // Should panic + } + + #[test] + #[should_panic(expected = "Percentile must be between 0 and 100")] + fn test_percentile_vertical_out_of_bounds() { + let m = Matrix::from_vec(vec![1.0, 2.0, 3.0], 1, 3); + let _ = percentile_vertical(&m, -0.1); + } }