From 517c58bfef1a6e6511025c458ad0ba6b37ec066a Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Thu, 17 Apr 2025 11:58:53 +0100 Subject: [PATCH] add unit tests for Matrix and f64-specific helpers --- tests/mat_tests.rs | 124 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 tests/mat_tests.rs diff --git a/tests/mat_tests.rs b/tests/mat_tests.rs new file mode 100644 index 0000000..7ac9c6b --- /dev/null +++ b/tests/mat_tests.rs @@ -0,0 +1,124 @@ +// Unit tests for Matrix and its f64‑specific helpers. + +#[cfg(test)] +mod tests { + use rustframe::frame::mat::{Axis, Matrix}; + use std::f64::NAN; + + const EPS: f64 = 1e-12; + + fn assert_vec_f64_eq(a: &[f64], b: &[f64]) { + assert_eq!(a.len(), b.len(), "length mismatch"); + for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() { + assert!((*x - *y).abs() < EPS, "index {i} differs: {x} vs {y}"); + } + } + + // Constructors + + #[test] + fn from_cols_basic() { + let m = Matrix::from_cols(vec![vec![1, 2], vec![3, 4]]); + assert_eq!(m.rows(), 2); + assert_eq!(m.cols(), 2); + assert_eq!(m[(0, 0)], 1); + assert_eq!(m[(1, 0)], 2); + assert_eq!(m[(0, 1)], 3); + assert_eq!(m[(1, 1)], 4); + } + + #[test] + #[should_panic(expected = "need at least one column")] + fn from_cols_zero_columns_panics() { + let _ = Matrix::::from_cols(vec![]); + } + + #[test] + #[should_panic(expected = "col 1")] + fn from_cols_mismatched_lengths_panics() { + let _ = Matrix::from_cols(vec![vec![1, 2], vec![3]]); + } + + #[test] + fn from_vec_basic() { + // column‑major order: (r,c) = value + // (0,0)=1 (1,0)=2 (0,1)=3 (1,1)=4 (0,2)=5 (1,2)=6 + let data = vec![1, 2, 3, 4, 5, 6]; + let m = Matrix::from_vec(data, 2, 3); + assert_eq!(m[(0, 2)], 5); + assert_eq!(m[(1, 2)], 6); + } + + // Indexing & mutation + + #[test] + fn index_mut_works() { + let mut m = Matrix::from_cols(vec![vec![1, 2]]); + *m.get_mut(0, 0) = 10; + assert_eq!(m[(0, 0)], 10); + } + + #[test] + #[should_panic(expected = "index out of bounds")] + fn index_out_of_bounds_panics() { + let m = Matrix::from_cols(vec![vec![1]]); + let _ = m[(1, 0)]; + } + + // Column swapping + + #[test] + fn swap_columns_basic() { + let mut m = Matrix::from_cols(vec![vec![1, 2], vec![3, 4]]); + m.swap_columns(0, 1); + let expected = Matrix::from_cols(vec![vec![3, 4], vec![1, 2]]); + assert_eq!(m, expected); + } + + // ── Element‑wise ops ─ + + #[test] + fn elementwise_add_sub_mul_div() { + let a = Matrix::from_cols(vec![vec![1, 2], vec![3, 4]]); + let b = Matrix::from_cols(vec![vec![10, 20], vec![30, 40]]); + + assert_eq!(&a + &b, Matrix::from_cols(vec![vec![11, 22], vec![33, 44]])); + assert_eq!(&b - &a, Matrix::from_cols(vec![vec![9, 18], vec![27, 36]])); + assert_eq!( + &a * &b, + Matrix::from_cols(vec![vec![10, 40], vec![90, 160]]) + ); + assert_eq!(&b / &a, Matrix::from_cols(vec![vec![10, 10], vec![10, 10]])); + } + + // f64‑specific helpers + + #[test] + fn reductions_with_nan() { + let m = Matrix::from_cols(vec![vec![1.0, NAN, 3.0], vec![4.0, 5.0, NAN]]); + // Matrix: 3 rows × 2 columns + assert_vec_f64_eq(&m.sum_vertical(), &[4.0, 9.0]); + assert_vec_f64_eq(&m.sum_horizontal(), &[5.0, 5.0, 3.0]); + assert_vec_f64_eq(&m.prod_vertical(), &[3.0, 20.0]); + assert_vec_f64_eq(&m.prod_horizontal(), &[4.0, 5.0, 3.0]); + + assert_eq!(m.count_nan_vertical(), vec![1, 1]); + assert_eq!(m.count_nan_horizontal(), vec![0, 1, 1]); + } + + #[test] + fn apply_axis_dispatch() { + let m = Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); + let col_sums = m.apply_axis(Axis::Col, |col| col.iter().sum::()); + let row_sums = m.apply_axis(Axis::Row, |row| row.iter().sum::()); + assert_vec_f64_eq(&col_sums, &[3.0, 7.0]); + assert_vec_f64_eq(&row_sums, &[4.0, 6.0]); + } + + #[test] + fn is_nan_mask() { + let m = Matrix::from_cols(vec![vec![1.0, NAN]]); + let expected = Matrix::from_cols(vec![vec![false, true]]); + assert_eq!(m.is_nan(), expected); + } +}