diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index ec61e66..f748ced 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -621,7 +621,242 @@ pub type StringMatrix = Matrix; mod tests { use crate::matrix::BoolOps; - use super::{BoolMatrix, FloatMatrix, Matrix, StringMatrix}; + use super::*; // Import items from outer scope + + // Helper to create a 2x2 f64 matrix easily + fn make_f64_matrix(a: f64, b: f64, c: f64, d: f64) -> FloatMatrix { + Matrix::from_cols(vec![vec![a, c], vec![b, d]]) + } + + // Helper to create a 2x2 bool matrix easily + fn make_bool_matrix(a: bool, b: bool, c: bool, d: bool) -> BoolMatrix { + Matrix::from_cols(vec![vec![a, c], vec![b, d]]) + } + + // --- Arithmetic Tests --- + + #[test] + fn test_add_f64() { + let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + let m2 = make_f64_matrix(5.0, 6.0, 7.0, 8.0); + let expected = make_f64_matrix(6.0, 8.0, 10.0, 12.0); + + assert_eq!(m1.clone() + m2.clone(), expected, "M + M"); + assert_eq!(m1.clone() + &m2, expected, "M + &M"); + assert_eq!(&m1 + m2.clone(), expected, "&M + M"); + assert_eq!(&m1 + &m2, expected, "&M + &M"); + } + + #[test] + fn test_add_scalar_f64() { + let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + let scalar = 10.0; + let expected = make_f64_matrix(11.0, 12.0, 13.0, 14.0); + + assert_eq!(m1.clone() + scalar, expected, "M + S"); + assert_eq!(&m1 + scalar, expected, "&M + S"); + } + + #[test] + fn test_sub_f64() { + let m1 = make_f64_matrix(10.0, 20.0, 30.0, 40.0); + let m2 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + let expected = make_f64_matrix(9.0, 18.0, 27.0, 36.0); + + assert_eq!(m1.clone() - m2.clone(), expected, "M - M"); + assert_eq!(m1.clone() - &m2, expected, "M - &M"); + assert_eq!(&m1 - m2.clone(), expected, "&M - M"); + assert_eq!(&m1 - &m2, expected, "&M - &M"); + } + + #[test] + fn test_sub_scalar_f64() { + let m1 = make_f64_matrix(11.0, 12.0, 13.0, 14.0); + let scalar = 10.0; + let expected = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + + assert_eq!(m1.clone() - scalar, expected, "M - S"); + assert_eq!(&m1 - scalar, expected, "&M - S"); + } + + #[test] + fn test_mul_f64() { + // Element-wise + let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + let m2 = make_f64_matrix(5.0, 6.0, 7.0, 8.0); + let expected = make_f64_matrix(5.0, 12.0, 21.0, 32.0); + + assert_eq!(m1.clone() * m2.clone(), expected, "M * M"); + assert_eq!(m1.clone() * &m2, expected, "M * &M"); + assert_eq!(&m1 * m2.clone(), expected, "&M * M"); + assert_eq!(&m1 * &m2, expected, "&M * &M"); + } + + #[test] + fn test_mul_scalar_f64() { + let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + let scalar = 3.0; + let expected = make_f64_matrix(3.0, 6.0, 9.0, 12.0); + + assert_eq!(m1.clone() * scalar, expected, "M * S"); + assert_eq!(&m1 * scalar, expected, "&M * S"); + } + + #[test] + fn test_div_f64() { + // Element-wise + let m1 = make_f64_matrix(10.0, 20.0, 30.0, 40.0); + let m2 = make_f64_matrix(2.0, 5.0, 6.0, 8.0); + let expected = make_f64_matrix(5.0, 4.0, 5.0, 5.0); + + assert_eq!(m1.clone() / m2.clone(), expected, "M / M"); + assert_eq!(m1.clone() / &m2, expected, "M / &M"); + assert_eq!(&m1 / m2.clone(), expected, "&M / M"); + assert_eq!(&m1 / &m2, expected, "&M / &M"); + } + + #[test] + fn test_div_scalar_f64() { + let m1 = make_f64_matrix(10.0, 20.0, 30.0, 40.0); + let scalar = 10.0; + let expected = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + + assert_eq!(m1.clone() / scalar, expected, "M / S"); + assert_eq!(&m1 / scalar, expected, "&M / S"); + } + + #[test] + fn test_chained_ops_f64() { + let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + let result = (((m.clone() + 1.0) * 2.0) - 4.0) / 2.0; + // Expected: + // m+1 = [2, 3], [4, 5] + // *2 = [4, 6], [8, 10] + // -4 = [0, 2], [4, 6] + // /2 = [0, 1], [2, 3] + let expected = make_f64_matrix(0.0, 1.0, 2.0, 3.0); + assert_eq!(result, expected); + } + + // --- Boolean Logic Tests --- + + #[test] + fn test_bitand_bool() { + let m1 = make_bool_matrix(true, false, true, false); + let m2 = make_bool_matrix(true, true, false, false); + let expected = make_bool_matrix(true, false, false, false); + + assert_eq!(m1.clone() & m2.clone(), expected, "M & M"); + assert_eq!(m1.clone() & &m2, expected, "M & &M"); + assert_eq!(&m1 & m2.clone(), expected, "&M & M"); + assert_eq!(&m1 & &m2, expected, "&M & &M"); + } + + #[test] + fn test_bitor_bool() { + let m1 = make_bool_matrix(true, false, true, false); + let m2 = make_bool_matrix(true, true, false, false); + let expected = make_bool_matrix(true, true, true, false); + + assert_eq!(m1.clone() | m2.clone(), expected, "M | M"); + assert_eq!(m1.clone() | &m2, expected, "M | &M"); + assert_eq!(&m1 | m2.clone(), expected, "&M | M"); + assert_eq!(&m1 | &m2, expected, "&M | &M"); + } + + #[test] + fn test_bitxor_bool() { + let m1 = make_bool_matrix(true, false, true, false); + let m2 = make_bool_matrix(true, true, false, false); + let expected = make_bool_matrix(false, true, true, false); + + assert_eq!(m1.clone() ^ m2.clone(), expected, "M ^ M"); + assert_eq!(m1.clone() ^ &m2, expected, "M ^ &M"); + assert_eq!(&m1 ^ m2.clone(), expected, "&M ^ M"); + assert_eq!(&m1 ^ &m2, expected, "&M ^ &M"); + } + + #[test] + fn test_not_bool() { + let m = make_bool_matrix(true, false, true, false); + let expected = make_bool_matrix(false, true, false, true); + + assert_eq!(!m.clone(), expected, "!M (consuming)"); + assert_eq!(!&m, expected, "!&M (borrowing)"); + + // Check original is unchanged when using !&M + let original = make_bool_matrix(true, false, true, false); + let _negated_ref = !&original; + assert_eq!(original, make_bool_matrix(true, false, true, false)); + } + + // --- Comparison Tests --- + #[test] + fn test_comparison_eq_elem() { + let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + let m2 = make_f64_matrix(1.0, 0.0, 3.0, 5.0); + let s = 3.0; + let expected_m = make_bool_matrix(true, false, true, false); + let expected_s = make_bool_matrix(false, false, true, false); + + assert_eq!(m1.eq_elem(m2), expected_m, "eq_elem matrix"); + assert_eq!(m1.eq_elem(s), expected_s, "eq_elem scalar"); + } + + #[test] + fn test_comparison_gt_elem() { + let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + let m2 = make_f64_matrix(0.0, 3.0, 3.0, 5.0); + let s = 2.5; + let expected_m = make_bool_matrix(true, false, false, false); + let expected_s = make_bool_matrix(false, false, true, true); + + assert_eq!(m1.gt_elem(m2), expected_m, "gt_elem matrix"); + assert_eq!(m1.gt_elem(s), expected_s, "gt_elem scalar"); + } + + // Add more comparison tests (lt, le, ge, ne) if desired... + + // --- Basic Method Tests --- + #[test] + fn test_indexing() { + let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + assert_eq!(m[(0, 0)], 1.0); + assert_eq!(m[(0, 1)], 2.0); + assert_eq!(m[(1, 0)], 3.0); + assert_eq!(m[(1, 1)], 4.0); + assert_eq!(*m.get(1, 0), 3.0); // Test get() too + } + + #[test] + #[should_panic] + fn test_index_out_of_bounds_row() { + let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + let _ = m[(2, 0)]; // Row 2 is out of bounds + } + + #[test] + #[should_panic] + fn test_index_out_of_bounds_col() { + let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + let _ = m[(0, 2)]; // Col 2 is out of bounds + } + + #[test] + fn test_dimensions() { + let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + assert_eq!(m.rows(), 2); + assert_eq!(m.cols(), 2); + } + + #[test] + fn test_from_vec() { + let data = vec![1.0, 3.0, 2.0, 4.0]; // Column major: [col0_row0, col0_row1, col1_row0, col1_row1] + let m = Matrix::from_vec(data, 2, 2); + let expected = make_f64_matrix(1.0, 2.0, 3.0, 4.0); + assert_eq!(m, expected); + assert_eq!(m.to_vec(), vec![1.0, 3.0, 2.0, 4.0]); + } // Helper function to create a basic Matrix for testing fn static_test_matrix() -> Matrix {