From 1ccdee45ed6d371a6c7afae56919fefd31aabda4 Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Fri, 18 Apr 2025 22:39:38 +0100 Subject: [PATCH] implement tests for the matrix struct --- tests/mat_tests.rs | 928 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 847 insertions(+), 81 deletions(-) diff --git a/tests/mat_tests.rs b/tests/mat_tests.rs index 7ac9c6b..d850500 100644 --- a/tests/mat_tests.rs +++ b/tests/mat_tests.rs @@ -1,124 +1,890 @@ -// Unit tests for Matrix and its f64‑specific helpers. - #[cfg(test)] mod tests { - use rustframe::frame::mat::{Axis, Matrix}; - use std::f64::NAN; + use rustframe::matrix::{BoolMatrix, FloatMatrix, Matrix, StringMatrix}; - const EPS: f64 = 1e-12; - - fn assert_vec_f64_eq(a: &[f64], b: &[f64]) { - assert_eq!(a.len(), b.len(), "length mismatch"); - for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() { - assert!((*x - *y).abs() < EPS, "index {i} differs: {x} vs {y}"); - } + // Helper function to create a basic Matrix for testing + fn create_test_matrix() -> Matrix { + // Column-major data: + // 1 4 7 + // 2 5 8 + // 3 6 9 + let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; + Matrix::from_vec(data, 3, 3) } - // Constructors + // Another helper for a different size + fn create_test_matrix_2x4() -> Matrix { + // Column-major data: + // 1 3 5 7 + // 2 4 6 8 + let data = vec![1, 2, 3, 4, 5, 6, 7, 8]; + Matrix::from_vec(data, 2, 4) + } #[test] - fn from_cols_basic() { - let m = Matrix::from_cols(vec![vec![1, 2], vec![3, 4]]); - assert_eq!(m.rows(), 2); - assert_eq!(m.cols(), 2); - assert_eq!(m[(0, 0)], 1); - assert_eq!(m[(1, 0)], 2); - assert_eq!(m[(0, 1)], 3); - assert_eq!(m[(1, 1)], 4); + fn test_from_vec_basic() { + let data = vec![1, 2, 3, 4, 5, 6]; // 2 rows, 3 cols (column-major) + let matrix = Matrix::from_vec(data, 2, 3); + + assert_eq!(matrix.rows(), 2); + assert_eq!(matrix.cols(), 3); + assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6]); + + // Check some elements + assert_eq!(matrix[(0, 0)], 1); // First row, first col + assert_eq!(matrix[(1, 0)], 2); // Second row, first col + assert_eq!(matrix[(0, 1)], 3); // First row, second col + assert_eq!(matrix[(1, 2)], 6); // Second row, third col + } + + #[test] + #[should_panic(expected = "data length mismatch")] + fn test_from_vec_wrong_length() { + let data = vec![1, 2, 3, 4, 5]; // Should be 6 for 2x3 + Matrix::from_vec(data, 2, 3); + } + + #[test] + #[should_panic(expected = "need at least one row")] + fn test_from_vec_zero_rows() { + let data = vec![1, 2, 3]; + Matrix::from_vec(data, 0, 3); } #[test] #[should_panic(expected = "need at least one column")] - fn from_cols_zero_columns_panics() { - let _ = Matrix::::from_cols(vec![]); + fn test_from_vec_zero_cols() { + let data = vec![1, 2, 3]; + Matrix::from_vec(data, 3, 0); } #[test] - #[should_panic(expected = "col 1")] - fn from_cols_mismatched_lengths_panics() { - let _ = Matrix::from_cols(vec![vec![1, 2], vec![3]]); + fn test_from_cols_basic() { + // Representing: + // 1 4 7 + // 2 5 8 + // 3 6 9 + let cols_data = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]; + let matrix = Matrix::from_cols(cols_data); + + assert_eq!(matrix.rows(), 3); + assert_eq!(matrix.cols(), 3); + // Internal data should be column-major + assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]); + + // Check some elements + assert_eq!(matrix[(0, 0)], 1); + assert_eq!(matrix[(2, 0)], 3); + assert_eq!(matrix[(1, 1)], 5); + assert_eq!(matrix[(0, 2)], 7); } #[test] - fn from_vec_basic() { - // column‑major order: (r,c) = value - // (0,0)=1 (1,0)=2 (0,1)=3 (1,1)=4 (0,2)=5 (1,2)=6 - let data = vec![1, 2, 3, 4, 5, 6]; - let m = Matrix::from_vec(data, 2, 3); - assert_eq!(m[(0, 2)], 5); - assert_eq!(m[(1, 2)], 6); + fn test_from_cols_1x1() { + let cols_data = vec![vec![42]]; + let matrix = Matrix::from_cols(cols_data); + assert_eq!(matrix.rows(), 1); + assert_eq!(matrix.cols(), 1); + assert_eq!(matrix.data(), &[42]); + assert_eq!(matrix[(0, 0)], 42); } - // Indexing & mutation + #[test] + #[should_panic(expected = "need at least one column")] + fn test_from_cols_empty_cols() { + let empty_cols: Vec> = vec![]; + Matrix::from_cols(empty_cols); + } #[test] - fn index_mut_works() { - let mut m = Matrix::from_cols(vec![vec![1, 2]]); - *m.get_mut(0, 0) = 10; - assert_eq!(m[(0, 0)], 10); + #[should_panic(expected = "need at least one row")] + fn test_from_cols_empty_rows() { + let empty_row: Vec> = vec![vec![], vec![]]; + Matrix::from_cols(empty_row); + } + + #[test] + #[should_panic(expected = "col 1 has len 2, expected 3")] + fn test_from_cols_mismatched_lengths() { + let cols_data = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8]]; + Matrix::from_cols(cols_data); + } + + #[test] + fn test_getters() { + let matrix = create_test_matrix(); + assert_eq!(matrix.rows(), 3); + assert_eq!(matrix.cols(), 3); + assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]); + } + + #[test] + fn test_index_and_get() { + let matrix = create_test_matrix(); + assert_eq!(matrix[(0, 0)], 1); + assert_eq!(matrix[(1, 1)], 5); + assert_eq!(matrix[(2, 2)], 9); + + assert_eq!(*matrix.get(0, 0), 1); + assert_eq!(*matrix.get(1, 1), 5); + assert_eq!(*matrix.get(2, 2), 9); } #[test] #[should_panic(expected = "index out of bounds")] - fn index_out_of_bounds_panics() { - let m = Matrix::from_cols(vec![vec![1]]); - let _ = m[(1, 0)]; + fn test_index_out_of_bounds_row() { + let matrix = create_test_matrix(); // 3x3 + let _ = matrix[(3, 0)]; } - // Column swapping - #[test] - fn swap_columns_basic() { - let mut m = Matrix::from_cols(vec![vec![1, 2], vec![3, 4]]); - m.swap_columns(0, 1); - let expected = Matrix::from_cols(vec![vec![3, 4], vec![1, 2]]); - assert_eq!(m, expected); + #[should_panic(expected = "index out of bounds")] + fn test_index_out_of_bounds_col() { + let matrix = create_test_matrix(); // 3x3 + let _ = matrix[(0, 3)]; } - // ── Element‑wise ops ─ + #[test] + fn test_index_mut_and_get_mut() { + let mut matrix = create_test_matrix(); // 3x3 + + matrix[(0, 0)] = 10; + matrix[(1, 1)] = 20; + matrix[(2, 2)] = 30; + + assert_eq!(matrix[(0, 0)], 10); + assert_eq!(matrix[(1, 1)], 20); + assert_eq!(matrix[(2, 2)], 30); + + *matrix.get_mut(0, 1) = 15; + *matrix.get_mut(2, 1) = 25; + + assert_eq!(matrix[(0, 1)], 15); + assert_eq!(matrix[(2, 1)], 25); + + // Check underlying data consistency (column-major) + // Should be: + // 10 15 7 + // 2 20 8 + // 3 25 30 + assert_eq!(matrix.data(), &[10, 2, 3, 15, 20, 25, 7, 8, 30]); + } #[test] - fn elementwise_add_sub_mul_div() { - let a = Matrix::from_cols(vec![vec![1, 2], vec![3, 4]]); - let b = Matrix::from_cols(vec![vec![10, 20], vec![30, 40]]); + #[should_panic(expected = "index out of bounds")] + fn test_index_mut_out_of_bounds_row() { + let mut matrix = create_test_matrix(); // 3x3 + matrix[(3, 0)] = 99; + } - assert_eq!(&a + &b, Matrix::from_cols(vec![vec![11, 22], vec![33, 44]])); - assert_eq!(&b - &a, Matrix::from_cols(vec![vec![9, 18], vec![27, 36]])); + #[test] + #[should_panic(expected = "index out of bounds")] + fn test_index_mut_out_of_bounds_col() { + let mut matrix = create_test_matrix(); // 3x3 + matrix[(0, 3)] = 99; + } + + #[test] + fn test_column() { + let matrix = create_test_matrix_2x4(); // 2x4 + // 1 3 5 7 + // 2 4 6 8 + + assert_eq!(matrix.column(0), &[1, 2]); + assert_eq!(matrix.column(1), &[3, 4]); + assert_eq!(matrix.column(2), &[5, 6]); + assert_eq!(matrix.column(3), &[7, 8]); + } + + #[test] + #[should_panic(expected = "range end index")] + fn test_column_out_of_bounds() { + let matrix = create_test_matrix_2x4(); // 2x4 + matrix.column(4); + } + + #[test] + fn test_column_mut() { + let mut matrix = create_test_matrix_2x4(); // 2x4 + // 1 3 5 7 + // 2 4 6 8 + + let col1_mut = matrix.column_mut(1); + col1_mut[0] = 30; + col1_mut[1] = 40; + + let col3_mut = matrix.column_mut(3); + col3_mut[0] = 70; + + // Check changes via indexing + assert_eq!(matrix[(0, 1)], 30); + assert_eq!(matrix[(1, 1)], 40); + assert_eq!(matrix[(0, 3)], 70); + assert_eq!(matrix[(1, 3)], 8); // Unchanged + + // Check underlying data (column-major) + // Should be: + // 1 30 5 70 + // 2 40 6 8 + assert_eq!(matrix.data(), &[1, 2, 30, 40, 5, 6, 70, 8]); + } + + #[test] + #[should_panic(expected = "range end index")] + fn test_column_mut_out_of_bounds() { + let mut matrix = create_test_matrix_2x4(); // 2x4 + matrix.column_mut(4); + } + + #[test] + fn test_iter_columns() { + let matrix = create_test_matrix_2x4(); // 2x4 + // 1 3 5 7 + // 2 4 6 8 + + let cols: Vec<&[i32]> = matrix.iter_columns().collect(); + assert_eq!(cols.len(), 4); + assert_eq!(cols[0], &[1, 2]); + assert_eq!(cols[1], &[3, 4]); + assert_eq!(cols[2], &[5, 6]); + assert_eq!(cols[3], &[7, 8]); + } + + #[test] + fn test_iter_rows() { + let matrix = create_test_matrix_2x4(); // 2x4 + // 1 3 5 7 + // 2 4 6 8 + + let rows: Vec> = matrix + .iter_rows() + .map(|row| row.iter().cloned().collect()) + .collect(); + assert_eq!(rows.len(), 2); + assert_eq!(rows[0], vec![1, 3, 5, 7]); + assert_eq!(rows[1], vec![2, 4, 6, 8]); + } + + #[test] + fn test_matrix_row_get_and_iter() { + let matrix = create_test_matrix_2x4(); // 2x4 + // 1 3 5 7 + // 2 4 6 8 + + let row0 = matrix.iter_rows().next().unwrap(); + assert_eq!(*row0.get(0), 1); + assert_eq!(*row0.get(1), 3); + assert_eq!(*row0.get(3), 7); + let row0_vec: Vec = row0.iter().cloned().collect(); + assert_eq!(row0_vec, vec![1, 3, 5, 7]); + + let row1 = matrix.iter_rows().nth(1).unwrap(); + assert_eq!(*row1.get(0), 2); + assert_eq!(*row1.get(2), 6); + let row1_vec: Vec = row1.iter().cloned().collect(); + assert_eq!(row1_vec, vec![2, 4, 6, 8]); + } + + #[test] + fn test_swap_columns() { + let mut matrix = create_test_matrix(); // 3x3 + // 1 4 7 + // 2 5 8 + // 3 6 9 + + matrix.swap_columns(0, 2); // Swap first and last + + // Should be: + // 7 4 1 + // 8 5 2 + // 9 6 3 + + assert_eq!(matrix.rows(), 3); + assert_eq!(matrix.cols(), 3); + assert_eq!(matrix[(0, 0)], 7); + assert_eq!(matrix[(1, 0)], 8); + assert_eq!(matrix[(2, 0)], 9); + assert_eq!(matrix[(0, 1)], 4); // Middle col unchanged + assert_eq!(matrix[(1, 1)], 5); + assert_eq!(matrix[(2, 1)], 6); + assert_eq!(matrix[(0, 2)], 1); + assert_eq!(matrix[(1, 2)], 2); + assert_eq!(matrix[(2, 2)], 3); + + // Swap the same column (should do nothing) + let original_data = matrix.data().to_vec(); + matrix.swap_columns(1, 1); + assert_eq!(matrix.data(), &original_data); // Data should be identical + + // Check underlying data (column-major) after swap(0, 2) + assert_eq!(matrix.data(), &[7, 8, 9, 4, 5, 6, 1, 2, 3]); + } + + #[test] + #[should_panic(expected = "column index out of bounds")] + fn test_swap_columns_out_of_bounds() { + let mut matrix = create_test_matrix(); // 3x3 + matrix.swap_columns(0, 3); + } + + #[test] + fn test_delete_column() { + let mut matrix = create_test_matrix_2x4(); // 2x4 + // 1 3 5 7 + // 2 4 6 8 + + matrix.delete_column(1); // Delete the second column (index 1) + + // Should be: + // 1 5 7 + // 2 6 8 + + assert_eq!(matrix.rows(), 2); + assert_eq!(matrix.cols(), 3); + assert_eq!(matrix[(0, 0)], 1); + assert_eq!(matrix[(1, 0)], 2); + assert_eq!(matrix[(0, 1)], 5); + assert_eq!(matrix[(1, 1)], 6); + assert_eq!(matrix[(0, 2)], 7); + assert_eq!(matrix[(1, 2)], 8); + + // Check underlying data (column-major) + assert_eq!(matrix.data(), &[1, 2, 5, 6, 7, 8]); + + // Delete the first column + matrix.delete_column(0); + // Should be: + // 5 7 + // 6 8 + assert_eq!(matrix.rows(), 2); + assert_eq!(matrix.cols(), 2); + assert_eq!(matrix.data(), &[5, 6, 7, 8]); + + // Delete the last column + matrix.delete_column(1); + // Should be: + // 5 + // 6 + assert_eq!(matrix.rows(), 2); + assert_eq!(matrix.cols(), 1); + assert_eq!(matrix.data(), &[5, 6]); + + // Delete the only column + matrix.delete_column(0); + // Should be empty + assert_eq!(matrix.rows(), 2); // Rows stay the same + assert_eq!(matrix.cols(), 0); // Cols becomes 0 + assert_eq!(matrix.data(), &[]); + } + + #[test] + #[should_panic(expected = "column index out of bounds")] + fn test_delete_column_out_of_bounds() { + let mut matrix = create_test_matrix_2x4(); // 2x4 + matrix.delete_column(4); + } + + #[test] + fn test_delete_row() { + let mut matrix = create_test_matrix(); // 3x3 + // 1 4 7 + // 2 5 8 + // 3 6 9 + + matrix.delete_row(1); // Delete the second row (index 1) + + // Should be: + // 1 4 7 + // 3 6 9 + + assert_eq!(matrix.rows(), 2); + assert_eq!(matrix.cols(), 3); + assert_eq!(matrix[(0, 0)], 1); + assert_eq!(matrix[(1, 0)], 3); + assert_eq!(matrix[(0, 1)], 4); + assert_eq!(matrix[(1, 1)], 6); + assert_eq!(matrix[(0, 2)], 7); + assert_eq!(matrix[(1, 2)], 9); + + // Check underlying data (column-major) + // Original: [1, 2, 3, 4, 5, 6, 7, 8, 9] + // Delete row 1: [1, 3, 4, 6, 7, 9] + assert_eq!(matrix.data(), &[1, 3, 4, 6, 7, 9]); + + // Delete the first row + matrix.delete_row(0); + // Should be: + // 3 6 9 + assert_eq!(matrix.rows(), 1); + assert_eq!(matrix.cols(), 3); + assert_eq!(matrix.data(), &[3, 6, 9]); + + // Delete the last (and only) row + matrix.delete_row(0); + // Should be empty + assert_eq!(matrix.rows(), 0); // Rows becomes 0 + assert_eq!(matrix.cols(), 3); // Cols stay the same + assert_eq!(matrix.data(), &[]); + } + + #[test] + #[should_panic(expected = "row index out of bounds")] + fn test_delete_row_out_of_bounds() { + let mut matrix = create_test_matrix(); // 3x3 + matrix.delete_row(3); + } + + #[test] + fn test_add_column() { + let mut matrix = create_test_matrix_2x4(); // 2x4 + // 1 3 5 7 + // 2 4 6 8 + + let new_col = vec![9, 10]; + matrix.add_column(2, new_col); // Add at index 2 + + // Should be: + // 1 3 9 5 7 + // 2 4 10 6 8 + + assert_eq!(matrix.rows(), 2); + assert_eq!(matrix.cols(), 5); + assert_eq!(matrix[(0, 0)], 1); + assert_eq!(matrix[(1, 0)], 2); + assert_eq!(matrix[(0, 1)], 3); + assert_eq!(matrix[(1, 1)], 4); + assert_eq!(matrix[(0, 2)], 9); + assert_eq!(matrix[(1, 2)], 10); + assert_eq!(matrix[(0, 3)], 5); // Shifted + assert_eq!(matrix[(1, 3)], 6); + assert_eq!(matrix[(0, 4)], 7); // Shifted + assert_eq!(matrix[(1, 4)], 8); + + // Check underlying data (column-major) + // Original: [1, 2, 3, 4, 5, 6, 7, 8] + // Add [9, 10] at index 2: [1, 2, 3, 4, 9, 10, 5, 6, 7, 8] + assert_eq!(matrix.data(), &[1, 2, 3, 4, 9, 10, 5, 6, 7, 8]); + + // Add a column at the beginning + let new_col_start = vec![11, 12]; + matrix.add_column(0, new_col_start); + // Should be: + // 11 1 3 9 5 7 + // 12 2 4 10 6 8 + assert_eq!(matrix.rows(), 2); + assert_eq!(matrix.cols(), 6); + assert_eq!(matrix[(0, 0)], 11); + assert_eq!(matrix[(1, 0)], 12); + assert_eq!(matrix.data(), &[11, 12, 1, 2, 3, 4, 9, 10, 5, 6, 7, 8]); + + // Add a column at the end + let new_col_end = vec![13, 14]; + matrix.add_column(6, new_col_end); + // Should be: + // 11 1 3 9 5 7 13 + // 12 2 4 10 6 8 14 + assert_eq!(matrix.rows(), 2); + assert_eq!(matrix.cols(), 7); + assert_eq!(matrix[(0, 6)], 13); + assert_eq!(matrix[(1, 6)], 14); assert_eq!( - &a * &b, - Matrix::from_cols(vec![vec![10, 40], vec![90, 160]]) + matrix.data(), + &[11, 12, 1, 2, 3, 4, 9, 10, 5, 6, 7, 8, 13, 14] ); - assert_eq!(&b / &a, Matrix::from_cols(vec![vec![10, 10], vec![10, 10]])); - } - - // f64‑specific helpers - - #[test] - fn reductions_with_nan() { - let m = Matrix::from_cols(vec![vec![1.0, NAN, 3.0], vec![4.0, 5.0, NAN]]); - // Matrix: 3 rows × 2 columns - assert_vec_f64_eq(&m.sum_vertical(), &[4.0, 9.0]); - assert_vec_f64_eq(&m.sum_horizontal(), &[5.0, 5.0, 3.0]); - assert_vec_f64_eq(&m.prod_vertical(), &[3.0, 20.0]); - assert_vec_f64_eq(&m.prod_horizontal(), &[4.0, 5.0, 3.0]); - - assert_eq!(m.count_nan_vertical(), vec![1, 1]); - assert_eq!(m.count_nan_horizontal(), vec![0, 1, 1]); } #[test] - fn apply_axis_dispatch() { - let m = Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); - let col_sums = m.apply_axis(Axis::Col, |col| col.iter().sum::()); - let row_sums = m.apply_axis(Axis::Row, |row| row.iter().sum::()); - assert_vec_f64_eq(&col_sums, &[3.0, 7.0]); - assert_vec_f64_eq(&row_sums, &[4.0, 6.0]); + #[should_panic(expected = "column index out of bounds")] + fn test_add_column_out_of_bounds() { + let mut matrix = create_test_matrix_2x4(); // 2x4 + let new_col = vec![9, 10]; + matrix.add_column(5, new_col); // Index 5 is out of bounds for 4 columns } #[test] - fn is_nan_mask() { - let m = Matrix::from_cols(vec![vec![1.0, NAN]]); - let expected = Matrix::from_cols(vec![vec![false, true]]); - assert_eq!(m.is_nan(), expected); + #[should_panic(expected = "column length mismatch")] + fn test_add_column_length_mismatch() { + let mut matrix = create_test_matrix_2x4(); // 2x4 (2 rows) + let new_col = vec![9, 10, 11]; // Wrong length + matrix.add_column(0, new_col); } + + #[test] + fn test_add_row() { + let mut matrix = create_test_matrix_2x4(); // 2x4 + // 1 3 5 7 + // 2 4 6 8 + + let new_row = vec![9, 10, 11, 12]; + matrix.add_row(1, new_row); // Add at index 1 + + // Should be: + // 1 3 5 7 + // 9 10 11 12 + // 2 4 6 8 + + assert_eq!(matrix.rows(), 3); + assert_eq!(matrix.cols(), 4); + + assert_eq!(matrix[(0, 0)], 1); + assert_eq!(matrix[(0, 1)], 3); + assert_eq!(matrix[(0, 2)], 5); + assert_eq!(matrix[(0, 3)], 7); + assert_eq!(matrix[(1, 0)], 9); + assert_eq!(matrix[(1, 1)], 10); + assert_eq!(matrix[(1, 2)], 11); + assert_eq!(matrix[(1, 3)], 12); + assert_eq!(matrix[(2, 0)], 2); + assert_eq!(matrix[(2, 1)], 4); + assert_eq!(matrix[(2, 2)], 6); + assert_eq!(matrix[(2, 3)], 8); + + // Check underlying data (column-major) + // Original: [1, 2, 3, 4, 5, 6, 7, 8] (rows 0, 1) + // Add [9, 10, 11, 12] at index 1 (new row will be index 1, original row 1 becomes index 2) + // Col 0: [1, 9, 2] + // Col 1: [3, 10, 4] + // Col 2: [5, 11, 6] + // Col 3: [7, 12, 8] + // Data: [1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8] + assert_eq!(matrix.data(), &[1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]); + + // Add a row at the beginning + let new_row_start = vec![13, 14, 15, 16]; + matrix.add_row(0, new_row_start); + // Should be: + // 13 14 15 16 + // 1 3 5 7 + // 9 10 11 12 + // 2 4 6 8 + assert_eq!(matrix.rows(), 4); + assert_eq!(matrix.cols(), 4); + assert_eq!(matrix[(0, 0)], 13); + assert_eq!(matrix[(0, 3)], 16); + // Check some existing elements to ensure they shifted correctly + assert_eq!(matrix[(1, 0)], 1); + assert_eq!(matrix[(2, 1)], 10); + assert_eq!(matrix[(3, 3)], 8); + + // Add a row at the end + let new_row_end = vec![17, 18, 19, 20]; + matrix.add_row(4, new_row_end); + // Should be: + // 13 14 15 16 + // 1 3 5 7 + // 9 10 11 12 + // 2 4 6 8 + // 17 18 19 20 + assert_eq!(matrix.rows(), 5); + assert_eq!(matrix.cols(), 4); + assert_eq!(matrix[(4, 0)], 17); + assert_eq!(matrix[(4, 3)], 20); + } + + #[test] + #[should_panic(expected = "row index out of bounds")] + fn test_add_row_out_of_bounds() { + let mut matrix = create_test_matrix_2x4(); // 2x4 + let new_row = vec![9, 10, 11, 12]; + matrix.add_row(3, new_row); // Index 3 is out of bounds for 2 rows + } + + #[test] + #[should_panic(expected = "row length mismatch")] + fn test_add_row_length_mismatch() { + let mut matrix = create_test_matrix_2x4(); // 2x4 (4 cols) + let new_row = vec![9, 10, 11]; // Wrong length + matrix.add_row(0, new_row); + } + + #[test] + fn test_elementwise_add() { + let matrix1 = create_test_matrix(); // 3x3 + let matrix2 = Matrix::from_vec(vec![9, 8, 7, 6, 5, 4, 3, 2, 1], 3, 3); // 3x3 + + let result = &matrix1 + &matrix2; + + assert_eq!(result.rows(), 3); + assert_eq!(result.cols(), 3); + + // Expected: + // 1+9 4+6 7+3 => 10 10 10 + // 2+8 5+5 8+2 => 10 10 10 + // 3+7 6+4 9+1 => 10 10 10 + // Column-major data: [10, 10, 10, 10, 10, 10, 10, 10, 10] + assert_eq!(result.data(), &[10, 10, 10, 10, 10, 10, 10, 10, 10]); + assert_eq!(result[(0, 0)], 10); + assert_eq!(result[(1, 1)], 10); + assert_eq!(result[(2, 2)], 10); + } + + #[test] + fn test_elementwise_sub() { + let matrix1 = create_test_matrix(); // 3x3 + let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], 3, 3); // 3x3 + + let result = &matrix1 - &matrix2; + + assert_eq!(result.rows(), 3); + assert_eq!(result.cols(), 3); + + // Expected: + // 1-1 4-2 7-3 => 0 2 4 + // 2-1 5-2 8-3 => 1 3 5 + // 3-1 6-2 9-3 => 2 4 6 + // Column-major data: [0, 1, 2, 2, 3, 4, 4, 5, 6] + assert_eq!(result.data(), &[0, 1, 2, 2, 3, 4, 4, 5, 6]); + assert_eq!(result[(0, 0)], 0); + assert_eq!(result[(1, 1)], 3); + assert_eq!(result[(2, 2)], 6); + } + + #[test] + fn test_elementwise_mul() { + let matrix1 = create_test_matrix(); // 3x3 + let matrix2 = Matrix::from_vec(vec![1, 2, 3, 1, 2, 3, 1, 2, 3], 3, 3); // 3x3 + + let result = &matrix1 * &matrix2; + + assert_eq!(result.rows(), 3); + assert_eq!(result.cols(), 3); + + // Expected: + // 1*1 4*1 7*1 => 1 4 7 + // 2*2 5*2 8*2 => 4 10 16 + // 3*3 6*3 9*3 => 9 18 27 + // Column-major data: [1, 4, 9, 4, 10, 18, 7, 16, 27] + assert_eq!(result.data(), &[1, 4, 9, 4, 10, 18, 7, 16, 27]); + assert_eq!(result[(0, 0)], 1); + assert_eq!(result[(1, 1)], 10); + assert_eq!(result[(2, 2)], 27); + } + + #[test] + fn test_elementwise_div() { + let matrix1 = create_test_matrix(); // 3x3 + let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 7, 8, 9], 3, 3); // 3x3 + + let result = &matrix1 / &matrix2; // Integer division + + assert_eq!(result.rows(), 3); + assert_eq!(result.cols(), 3); + + // Expected: + // 1/1 4/2 7/7 => 1 2 1 + // 2/1 5/2 8/8 => 2 2 1 (integer division) + // 3/1 6/2 9/9 => 3 3 1 + // Column-major data: [1, 2, 3, 2, 2, 3, 1, 1, 1] + assert_eq!(result.data(), &[1, 2, 3, 2, 2, 3, 1, 1, 1]); + assert_eq!(result[(0, 0)], 1); + assert_eq!(result[(1, 1)], 2); + assert_eq!(result[(2, 2)], 1); + } + + #[test] + #[should_panic(expected = "row count mismatch")] + fn test_elementwise_op_row_mismatch() { + let matrix1 = create_test_matrix(); // 3x3 + let matrix2 = create_test_matrix_2x4(); // 2x4 + let _ = &matrix1 + &matrix2; // Should panic + } + + #[test] + #[should_panic(expected = "row count mismatch")] + fn test_elementwise_op_col_mismatch() { + let matrix1 = create_test_matrix(); // 3x3 + let matrix2 = create_test_matrix_2x4(); // 2x4 + let _ = &matrix1 * &matrix2; // Should panic + } + + #[test] + fn test_bitwise_and() { + let data1 = vec![true, false, true, false, true, false]; // 2x3 + let data2 = vec![true, true, false, false, true, true]; // 2x3 + let matrix1 = BoolMatrix::from_vec(data1, 2, 3); + let matrix2 = BoolMatrix::from_vec(data2, 2, 3); + + // Expected column-major results: + // T & T = T + // F & T = F + // T & F = F + // F & F = F + // T & T = T + // F & T = F + // Data: [T, F, F, F, T, F] + let expected_data = vec![true, false, false, false, true, false]; + let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); + + let result = &matrix1 & &matrix2; + assert_eq!(result, expected_matrix); + } + + #[test] + fn test_bitwise_or() { + let data1 = vec![true, false, true, false, true, false]; // 2x3 + let data2 = vec![true, true, false, false, true, true]; // 2x3 + let matrix1 = BoolMatrix::from_vec(data1, 2, 3); + let matrix2 = BoolMatrix::from_vec(data2, 2, 3); + + // Expected column-major results: + // T | T = T + // F | T = T + // T | F = T + // F | F = F + // T | T = T + // F | T = T + // Data: [T, T, T, F, T, T] + let expected_data = vec![true, true, true, false, true, true]; + let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); + + let result = &matrix1 | &matrix2; + assert_eq!(result, expected_matrix); + } + + #[test] + fn test_bitwise_xor() { + let data1 = vec![true, false, true, false, true, false]; // 2x3 + let data2 = vec![true, true, false, false, true, true]; // 2x3 + let matrix1 = BoolMatrix::from_vec(data1, 2, 3); + let matrix2 = BoolMatrix::from_vec(data2, 2, 3); + + // Expected column-major results: + // T ^ T = F + // F ^ T = T + // T ^ F = T + // F ^ F = F + // T ^ T = F + // F ^ T = T + // Data: [F, T, T, F, F, T] + let expected_data = vec![false, true, true, false, false, true]; + let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); + + let result = &matrix1 ^ &matrix2; + assert_eq!(result, expected_matrix); + } + + #[test] + fn test_bitwise_not() { + let data = vec![true, false, true, false, true, false]; // 2x3 + let matrix = BoolMatrix::from_vec(data, 2, 3); + + // Expected column-major results: + // !T = F + // !F = T + // !T = F + // !F = T + // !T = F + // !F = T + // Data: [F, T, F, T, F, T] + let expected_data = vec![false, true, false, true, false, true]; + let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); + + let result = !matrix; // Not consumes the matrix + assert_eq!(result, expected_matrix); + } + + #[test] + #[should_panic(expected = "col count mismatch")] + fn test_bitwise_op_row_mismatch() { + let data1 = vec![true, false, true, false]; // 2x2 + let data2 = vec![true, true, false, false, true, true]; // 2x3 + let matrix1 = BoolMatrix::from_vec(data1, 2, 2); + let matrix2 = BoolMatrix::from_vec(data2, 2, 3); + let _ = &matrix1 & &matrix2; // Should panic + } + + #[test] + #[should_panic(expected = "col count mismatch")] + fn test_bitwise_op_col_mismatch() { + let data1 = vec![true, false, true, false]; // 2x2 + let data2 = vec![true, true, false, false, true, true]; // 2x3 + let matrix1 = BoolMatrix::from_vec(data1, 2, 2); + let matrix2 = BoolMatrix::from_vec(data2, 2, 3); + let _ = &matrix1 | &matrix2; // Should panic + } + + // Test with String type (requires Clone, PartialEq) + #[test] + fn test_string_matrix() { + let data = vec![ + "a".to_string(), + "b".to_string(), + "c".to_string(), + "d".to_string(), + ]; + let matrix = StringMatrix::from_vec(data.clone(), 2, 2); // 2x2 + + assert_eq!(matrix[(0, 0)], "a".to_string()); + assert_eq!(matrix[(1, 0)], "b".to_string()); + assert_eq!(matrix[(0, 1)], "c".to_string()); + assert_eq!(matrix[(1, 1)], "d".to_string()); + + // Test modification + let mut matrix = matrix; + matrix[(0, 0)] = "hello".to_string(); + assert_eq!(matrix[(0, 0)], "hello".to_string()); + + // Test add_column (requires Clone) + let new_col = vec!["e".to_string(), "f".to_string()]; + matrix.add_column(1, new_col); // Add at index 1 + // Should be: + // hello c d + // b e f + assert_eq!(matrix.rows(), 2); + assert_eq!(matrix.cols(), 3); + assert_eq!(matrix[(0, 0)], "hello".to_string()); + assert_eq!(matrix[(1, 0)], "b".to_string()); + assert_eq!(matrix[(0, 1)], "e".to_string()); // New col + assert_eq!(matrix[(1, 1)], "f".to_string()); // New col + assert_eq!(matrix[(0, 2)], "c".to_string()); // Shifted + assert_eq!(matrix[(1, 2)], "d".to_string()); // Shifted + + // Test add_row (requires Clone) + let new_row = vec!["g".to_string(), "h".to_string(), "i".to_string()]; + matrix.add_row(0, new_row); // Add at index 0 + // Should be: + // g h i + // hello e c + // b f d + assert_eq!(matrix.rows(), 3); + assert_eq!(matrix.cols(), 3); + assert_eq!(matrix[(0, 0)], "g".to_string()); + assert_eq!(matrix[(0, 1)], "h".to_string()); + assert_eq!(matrix[(0, 2)], "i".to_string()); + assert_eq!(matrix[(1, 0)], "hello".to_string()); // Shifted + assert_eq!(matrix[(2, 2)], "d".to_string()); // Shifted + } + + #[test] + fn test_float_matrix_ops() { + let data1 = vec![1.0, 2.0, 3.0, 4.0]; // 2x2 + let data2 = vec![0.5, 1.5, 2.5, 3.5]; // 2x2 + let matrix1 = FloatMatrix::from_vec(data1, 2, 2); + let matrix2 = FloatMatrix::from_vec(data2, 2, 2); + + let sum = &matrix1 + &matrix2; + let diff = &matrix1 - &matrix2; + let prod = &matrix1 * &matrix2; + let div = &matrix1 / &matrix2; + + // Check sums (col-major): [1.5, 3.5, 5.5, 7.5] + assert_eq!(sum.data(), &[1.5, 3.5, 5.5, 7.5]); + + // Check diffs (col-major): [0.5, 0.5, 0.5, 0.5] + assert_eq!(diff.data(), &[0.5, 0.5, 0.5, 0.5]); + + // Check prods (col-major): [0.5, 3.0, 7.5, 14.0] + assert_eq!(prod.data(), &[0.5, 3.0, 7.5, 14.0]); + + // Check divs (col-major): [2.0, 1.333..., 1.2, 1.14...] + // Using element access for more specific checks on floating point results + assert_eq!(div.rows(), 2); + assert_eq!(div.cols(), 2); + assert!((div[(0, 0)] - 1.0 / 0.5).abs() < 1e-9); // 2.0 + assert!((div[(1, 0)] - 2.0 / 1.5).abs() < 1e-9); // 1.333... + assert!((div[(0, 1)] - 3.0 / 2.5).abs() < 1e-9); // 1.2 + assert!((div[(1, 1)] - 4.0 / 3.5).abs() < 1e-9); // 1.14... + } + + // Axis enum doesn't have logic, no tests needed directly, but its presence is verified by compilation. }