mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Refactor median and percentile functions to handle vertical and horizontal calculations correctly; add corresponding tests for validation
This commit is contained in:
parent
a2fcaf1d52
commit
5779c6b82d
@ -94,33 +94,35 @@ pub fn median(x: &Matrix<f64>) -> f64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn _median_axis(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
|
fn _median_axis(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
|
||||||
let mut data = match axis {
|
let mx = match axis {
|
||||||
Axis::Row => x.sum_vertical(),
|
Axis::Col => x.clone(),
|
||||||
Axis::Col => x.sum_horizontal(),
|
Axis::Row => x.transpose(),
|
||||||
};
|
};
|
||||||
data.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
|
||||||
let mid = data.len() / 2;
|
let mut result = Vec::with_capacity(mx.cols());
|
||||||
if data.len() % 2 == 0 {
|
for c in 0..mx.cols() {
|
||||||
Matrix::from_vec(
|
let mut col = mx.column(c).to_vec();
|
||||||
vec![(data[mid - 1] + data[mid]) / 2.0],
|
col.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||||
if axis == Axis::Row { 1 } else { x.rows() },
|
let mid = col.len() / 2;
|
||||||
if axis == Axis::Row { x.cols() } else { 1 },
|
if col.len() % 2 == 0 {
|
||||||
)
|
result.push((col[mid - 1] + col[mid]) / 2.0);
|
||||||
} else {
|
} else {
|
||||||
Matrix::from_vec(
|
result.push(col[mid]);
|
||||||
vec![data[mid]],
|
|
||||||
if axis == Axis::Row { 1 } else { x.rows() },
|
|
||||||
if axis == Axis::Row { x.cols() } else { 1 },
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
let (r, c) = match axis {
|
||||||
|
Axis::Col => (1, mx.cols()),
|
||||||
|
Axis::Row => (mx.cols(), 1),
|
||||||
|
};
|
||||||
|
Matrix::from_vec(result, r, c)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn median_vertical(x: &Matrix<f64>) -> Matrix<f64> {
|
pub fn median_vertical(x: &Matrix<f64>) -> Matrix<f64> {
|
||||||
_median_axis(x, Axis::Row)
|
_median_axis(x, Axis::Col)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn median_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
|
pub fn median_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
|
||||||
_median_axis(x, Axis::Col)
|
_median_axis(x, Axis::Row)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn percentile(x: &Matrix<f64>, p: f64) -> f64 {
|
pub fn percentile(x: &Matrix<f64>, p: f64) -> f64 {
|
||||||
@ -137,24 +139,29 @@ fn _percentile_axis(x: &Matrix<f64>, p: f64, axis: Axis) -> Matrix<f64> {
|
|||||||
if p < 0.0 || p > 100.0 {
|
if p < 0.0 || p > 100.0 {
|
||||||
panic!("Percentile must be between 0 and 100");
|
panic!("Percentile must be between 0 and 100");
|
||||||
}
|
}
|
||||||
let mut data = match axis {
|
let mx: Matrix<f64> = match axis {
|
||||||
Axis::Row => x.sum_vertical(),
|
Axis::Col => x.clone(),
|
||||||
Axis::Col => x.sum_horizontal(),
|
Axis::Row => x.transpose(),
|
||||||
};
|
};
|
||||||
data.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
let mut result = Vec::with_capacity(mx.cols());
|
||||||
let index = ((p / 100.0) * (data.len() as f64 - 1.0)).round() as usize;
|
for c in 0..mx.cols() {
|
||||||
Matrix::from_vec(
|
let mut col = mx.column(c).to_vec();
|
||||||
vec![data[index]],
|
col.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||||
if axis == Axis::Row { 1 } else { x.rows() },
|
let index = ((p / 100.0) * (col.len() as f64 - 1.0)).round() as usize;
|
||||||
if axis == Axis::Row { x.cols() } else { 1 },
|
result.push(col[index]);
|
||||||
)
|
}
|
||||||
|
let (r, c) = match axis {
|
||||||
|
Axis::Col => (1, mx.cols()),
|
||||||
|
Axis::Row => (mx.cols(), 1),
|
||||||
|
};
|
||||||
|
Matrix::from_vec(result, r, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn percentile_vertical(x: &Matrix<f64>, p: f64) -> Matrix<f64> {
|
pub fn percentile_vertical(x: &Matrix<f64>, p: f64) -> Matrix<f64> {
|
||||||
_percentile_axis(x, p, Axis::Row)
|
_percentile_axis(x, p, Axis::Col)
|
||||||
}
|
}
|
||||||
pub fn percentile_horizontal(x: &Matrix<f64>, p: f64) -> Matrix<f64> {
|
pub fn percentile_horizontal(x: &Matrix<f64>, p: f64) -> Matrix<f64> {
|
||||||
_percentile_axis(x, p, Axis::Col)
|
_percentile_axis(x, p, Axis::Row)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@ -250,14 +257,12 @@ mod tests {
|
|||||||
let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
|
let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
|
||||||
let x = Matrix::from_vec(data, 2, 3);
|
let x = Matrix::from_vec(data, 2, 3);
|
||||||
|
|
||||||
// Vertical variances (per column): each is ((v - mean)^2 summed / 2)
|
|
||||||
// cols: {1,4}, {2,5}, {3,6} all give 2.25
|
// cols: {1,4}, {2,5}, {3,6} all give 2.25
|
||||||
let vv = variance_vertical(&x);
|
let vv = variance_vertical(&x);
|
||||||
for c in 0..3 {
|
for c in 0..3 {
|
||||||
assert!((vv.get(0, c) - 2.25).abs() < EPSILON);
|
assert!((vv.get(0, c) - 2.25).abs() < EPSILON);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Horizontal variances (per row): rows [1,2,3] and [4,5,6] both give 2/3
|
|
||||||
let vh = variance_horizontal(&x);
|
let vh = variance_horizontal(&x);
|
||||||
assert!((vh.get(0, 0) - (2.0 / 3.0)).abs() < EPSILON);
|
assert!((vh.get(0, 0) - (2.0 / 3.0)).abs() < EPSILON);
|
||||||
assert!((vh.get(1, 0) - (2.0 / 3.0)).abs() < EPSILON);
|
assert!((vh.get(1, 0) - (2.0 / 3.0)).abs() < EPSILON);
|
||||||
@ -280,4 +285,59 @@ mod tests {
|
|||||||
assert!((sh.get(0, 0) - expected).abs() < EPSILON);
|
assert!((sh.get(0, 0) - expected).abs() < EPSILON);
|
||||||
assert!((sh.get(1, 0) - expected).abs() < EPSILON);
|
assert!((sh.get(1, 0) - expected).abs() < EPSILON);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_median_vertical_horizontal() {
|
||||||
|
let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
|
||||||
|
let x = Matrix::from_vec(data, 2, 3);
|
||||||
|
|
||||||
|
let mv = median_vertical(&x).row(0);
|
||||||
|
|
||||||
|
let expected_v = vec![2.5, 3.5, 4.5];
|
||||||
|
assert_eq!(mv, expected_v, "{:?} expected: {:?}", expected_v, mv);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_percentile_vertical_horizontal() {
|
||||||
|
// f64 values 1 through 24 (inclusive) arranged as a 4x6 matrix
|
||||||
|
let data: Vec<f64> = (1..=24).map(|x| x as f64).collect();
|
||||||
|
let x = Matrix::from_vec(data, 4, 6);
|
||||||
|
|
||||||
|
// matrix layout (data is stored column-major, so each column reads top to bottom):
|
||||||
|
// 1, 5, 9, 13, 17, 21
|
||||||
|
// 2, 6, 10, 14, 18, 22
|
||||||
|
// 3, 7, 11, 15, 19, 23
|
||||||
|
// 4, 8, 12, 16, 20, 24
|
||||||
|
|
||||||
|
let er0 = vec![1., 5., 9., 13., 17., 21.];
|
||||||
|
let er50 = vec![3., 7., 11., 15., 19., 23.];
|
||||||
|
let er100 = vec![4., 8., 12., 16., 20., 24.];
|
||||||
|
|
||||||
|
assert_eq!(percentile_vertical(&x, 0.0).data(), er0);
|
||||||
|
assert_eq!(percentile_vertical(&x, 50.0).data(), er50);
|
||||||
|
assert_eq!(percentile_vertical(&x, 100.0).data(), er100);
|
||||||
|
|
||||||
|
let eh0 = vec![1., 2., 3., 4.];
|
||||||
|
let eh50 = vec![13., 14., 15., 16.];
|
||||||
|
let eh100 = vec![21., 22., 23., 24.];
|
||||||
|
|
||||||
|
assert_eq!(percentile_horizontal(&x, 0.0).data(), eh0);
|
||||||
|
assert_eq!(percentile_horizontal(&x, 50.0).data(), eh50);
|
||||||
|
assert_eq!(percentile_horizontal(&x, 100.0).data(), eh100);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "Percentile must be between 0 and 100")]
|
||||||
|
fn test_percentile_out_of_bounds() {
|
||||||
|
let data = vec![1.0, 2.0, 3.0];
|
||||||
|
let x = Matrix::from_vec(data, 1, 3);
|
||||||
|
percentile(&x, -10.0); // Should panic
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "Percentile must be between 0 and 100")]
|
||||||
|
fn test_percentile_vertical_out_of_bounds() {
|
||||||
|
let m = Matrix::from_vec(vec![1.0, 2.0, 3.0], 1, 3);
|
||||||
|
let _ = percentile_vertical(&m, -0.1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user