implement tests for the matrix struct

This commit is contained in:
Palash Tyagi 2025-04-18 22:39:38 +01:00
parent f0d92d5c63
commit 1ccdee45ed

View File

@ -1,124 +1,890 @@
// Unit tests for Matrix<T> and its f64specific helpers.
#[cfg(test)]
mod tests {
use rustframe::frame::mat::{Axis, Matrix};
use std::f64::NAN;
use rustframe::matrix::{BoolMatrix, FloatMatrix, Matrix, StringMatrix};
const EPS: f64 = 1e-12;
fn assert_vec_f64_eq(a: &[f64], b: &[f64]) {
assert_eq!(a.len(), b.len(), "length mismatch");
for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
assert!((*x - *y).abs() < EPS, "index {i} differs: {x} vs {y}");
}
// Helper function to create a basic Matrix for testing
fn create_test_matrix() -> Matrix<i32> {
// Column-major data:
// 1 4 7
// 2 5 8
// 3 6 9
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
Matrix::from_vec(data, 3, 3)
}
// Constructors
// Another helper for a different size
fn create_test_matrix_2x4() -> Matrix<i32> {
// Column-major data:
// 1 3 5 7
// 2 4 6 8
let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
Matrix::from_vec(data, 2, 4)
}
#[test]
fn from_cols_basic() {
let m = Matrix::from_cols(vec![vec![1, 2], vec![3, 4]]);
assert_eq!(m.rows(), 2);
assert_eq!(m.cols(), 2);
assert_eq!(m[(0, 0)], 1);
assert_eq!(m[(1, 0)], 2);
assert_eq!(m[(0, 1)], 3);
assert_eq!(m[(1, 1)], 4);
fn test_from_vec_basic() {
let data = vec![1, 2, 3, 4, 5, 6]; // 2 rows, 3 cols (column-major)
let matrix = Matrix::from_vec(data, 2, 3);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6]);
// Check some elements
assert_eq!(matrix[(0, 0)], 1); // First row, first col
assert_eq!(matrix[(1, 0)], 2); // Second row, first col
assert_eq!(matrix[(0, 1)], 3); // First row, second col
assert_eq!(matrix[(1, 2)], 6); // Second row, third col
}
#[test]
#[should_panic(expected = "data length mismatch")]
fn test_from_vec_wrong_length() {
let data = vec![1, 2, 3, 4, 5]; // Should be 6 for 2x3
Matrix::from_vec(data, 2, 3);
}
#[test]
#[should_panic(expected = "need at least one row")]
fn test_from_vec_zero_rows() {
let data = vec![1, 2, 3];
Matrix::from_vec(data, 0, 3);
}
#[test]
#[should_panic(expected = "need at least one column")]
fn from_cols_zero_columns_panics() {
let _ = Matrix::<i32>::from_cols(vec![]);
fn test_from_vec_zero_cols() {
let data = vec![1, 2, 3];
Matrix::from_vec(data, 3, 0);
}
#[test]
#[should_panic(expected = "col 1")]
fn from_cols_mismatched_lengths_panics() {
let _ = Matrix::from_cols(vec![vec![1, 2], vec![3]]);
fn test_from_cols_basic() {
// Representing:
// 1 4 7
// 2 5 8
// 3 6 9
let cols_data = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
let matrix = Matrix::from_cols(cols_data);
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3);
// Internal data should be column-major
assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
// Check some elements
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(2, 0)], 3);
assert_eq!(matrix[(1, 1)], 5);
assert_eq!(matrix[(0, 2)], 7);
}
#[test]
fn from_vec_basic() {
// columnmajor order: (r,c) = value
// (0,0)=1 (1,0)=2 (0,1)=3 (1,1)=4 (0,2)=5 (1,2)=6
let data = vec![1, 2, 3, 4, 5, 6];
let m = Matrix::from_vec(data, 2, 3);
assert_eq!(m[(0, 2)], 5);
assert_eq!(m[(1, 2)], 6);
fn test_from_cols_1x1() {
let cols_data = vec![vec![42]];
let matrix = Matrix::from_cols(cols_data);
assert_eq!(matrix.rows(), 1);
assert_eq!(matrix.cols(), 1);
assert_eq!(matrix.data(), &[42]);
assert_eq!(matrix[(0, 0)], 42);
}
// Indexing & mutation
#[test]
#[should_panic(expected = "need at least one column")]
fn test_from_cols_empty_cols() {
let empty_cols: Vec<Vec<i32>> = vec![];
Matrix::from_cols(empty_cols);
}
#[test]
fn index_mut_works() {
let mut m = Matrix::from_cols(vec![vec![1, 2]]);
*m.get_mut(0, 0) = 10;
assert_eq!(m[(0, 0)], 10);
#[should_panic(expected = "need at least one row")]
fn test_from_cols_empty_rows() {
let empty_row: Vec<Vec<String>> = vec![vec![], vec![]];
Matrix::from_cols(empty_row);
}
#[test]
#[should_panic(expected = "col 1 has len 2, expected 3")]
fn test_from_cols_mismatched_lengths() {
let cols_data = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8]];
Matrix::from_cols(cols_data);
}
#[test]
fn test_getters() {
let matrix = create_test_matrix();
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
}
#[test]
fn test_index_and_get() {
let matrix = create_test_matrix();
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(1, 1)], 5);
assert_eq!(matrix[(2, 2)], 9);
assert_eq!(*matrix.get(0, 0), 1);
assert_eq!(*matrix.get(1, 1), 5);
assert_eq!(*matrix.get(2, 2), 9);
}
#[test]
#[should_panic(expected = "index out of bounds")]
fn index_out_of_bounds_panics() {
let m = Matrix::from_cols(vec![vec![1]]);
let _ = m[(1, 0)];
fn test_index_out_of_bounds_row() {
let matrix = create_test_matrix(); // 3x3
let _ = matrix[(3, 0)];
}
// Column swapping
#[test]
fn swap_columns_basic() {
let mut m = Matrix::from_cols(vec![vec![1, 2], vec![3, 4]]);
m.swap_columns(0, 1);
let expected = Matrix::from_cols(vec![vec![3, 4], vec![1, 2]]);
assert_eq!(m, expected);
#[should_panic(expected = "index out of bounds")]
fn test_index_out_of_bounds_col() {
let matrix = create_test_matrix(); // 3x3
let _ = matrix[(0, 3)];
}
// ── Elementwise ops ─
#[test]
fn test_index_mut_and_get_mut() {
let mut matrix = create_test_matrix(); // 3x3
matrix[(0, 0)] = 10;
matrix[(1, 1)] = 20;
matrix[(2, 2)] = 30;
assert_eq!(matrix[(0, 0)], 10);
assert_eq!(matrix[(1, 1)], 20);
assert_eq!(matrix[(2, 2)], 30);
*matrix.get_mut(0, 1) = 15;
*matrix.get_mut(2, 1) = 25;
assert_eq!(matrix[(0, 1)], 15);
assert_eq!(matrix[(2, 1)], 25);
// Check underlying data consistency (column-major)
// Should be:
// 10 15 7
// 2 20 8
// 3 25 30
assert_eq!(matrix.data(), &[10, 2, 3, 15, 20, 25, 7, 8, 30]);
}
#[test]
fn elementwise_add_sub_mul_div() {
let a = Matrix::from_cols(vec![vec![1, 2], vec![3, 4]]);
let b = Matrix::from_cols(vec![vec![10, 20], vec![30, 40]]);
#[should_panic(expected = "index out of bounds")]
fn test_index_mut_out_of_bounds_row() {
let mut matrix = create_test_matrix(); // 3x3
matrix[(3, 0)] = 99;
}
assert_eq!(&a + &b, Matrix::from_cols(vec![vec![11, 22], vec![33, 44]]));
assert_eq!(&b - &a, Matrix::from_cols(vec![vec![9, 18], vec![27, 36]]));
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_index_mut_out_of_bounds_col() {
let mut matrix = create_test_matrix(); // 3x3
matrix[(0, 3)] = 99;
}
#[test]
fn test_column() {
let matrix = create_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
assert_eq!(matrix.column(0), &[1, 2]);
assert_eq!(matrix.column(1), &[3, 4]);
assert_eq!(matrix.column(2), &[5, 6]);
assert_eq!(matrix.column(3), &[7, 8]);
}
#[test]
#[should_panic(expected = "range end index")]
fn test_column_out_of_bounds() {
let matrix = create_test_matrix_2x4(); // 2x4
matrix.column(4);
}
#[test]
fn test_column_mut() {
let mut matrix = create_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
let col1_mut = matrix.column_mut(1);
col1_mut[0] = 30;
col1_mut[1] = 40;
let col3_mut = matrix.column_mut(3);
col3_mut[0] = 70;
// Check changes via indexing
assert_eq!(matrix[(0, 1)], 30);
assert_eq!(matrix[(1, 1)], 40);
assert_eq!(matrix[(0, 3)], 70);
assert_eq!(matrix[(1, 3)], 8); // Unchanged
// Check underlying data (column-major)
// Should be:
// 1 30 5 70
// 2 40 6 8
assert_eq!(matrix.data(), &[1, 2, 30, 40, 5, 6, 70, 8]);
}
#[test]
#[should_panic(expected = "range end index")]
fn test_column_mut_out_of_bounds() {
let mut matrix = create_test_matrix_2x4(); // 2x4
matrix.column_mut(4);
}
#[test]
fn test_iter_columns() {
let matrix = create_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
let cols: Vec<&[i32]> = matrix.iter_columns().collect();
assert_eq!(cols.len(), 4);
assert_eq!(cols[0], &[1, 2]);
assert_eq!(cols[1], &[3, 4]);
assert_eq!(cols[2], &[5, 6]);
assert_eq!(cols[3], &[7, 8]);
}
#[test]
fn test_iter_rows() {
let matrix = create_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
let rows: Vec<Vec<i32>> = matrix
.iter_rows()
.map(|row| row.iter().cloned().collect())
.collect();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0], vec![1, 3, 5, 7]);
assert_eq!(rows[1], vec![2, 4, 6, 8]);
}
#[test]
fn test_matrix_row_get_and_iter() {
let matrix = create_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
let row0 = matrix.iter_rows().next().unwrap();
assert_eq!(*row0.get(0), 1);
assert_eq!(*row0.get(1), 3);
assert_eq!(*row0.get(3), 7);
let row0_vec: Vec<i32> = row0.iter().cloned().collect();
assert_eq!(row0_vec, vec![1, 3, 5, 7]);
let row1 = matrix.iter_rows().nth(1).unwrap();
assert_eq!(*row1.get(0), 2);
assert_eq!(*row1.get(2), 6);
let row1_vec: Vec<i32> = row1.iter().cloned().collect();
assert_eq!(row1_vec, vec![2, 4, 6, 8]);
}
#[test]
fn test_swap_columns() {
let mut matrix = create_test_matrix(); // 3x3
// 1 4 7
// 2 5 8
// 3 6 9
matrix.swap_columns(0, 2); // Swap first and last
// Should be:
// 7 4 1
// 8 5 2
// 9 6 3
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], 7);
assert_eq!(matrix[(1, 0)], 8);
assert_eq!(matrix[(2, 0)], 9);
assert_eq!(matrix[(0, 1)], 4); // Middle col unchanged
assert_eq!(matrix[(1, 1)], 5);
assert_eq!(matrix[(2, 1)], 6);
assert_eq!(matrix[(0, 2)], 1);
assert_eq!(matrix[(1, 2)], 2);
assert_eq!(matrix[(2, 2)], 3);
// Swap the same column (should do nothing)
let original_data = matrix.data().to_vec();
matrix.swap_columns(1, 1);
assert_eq!(matrix.data(), &original_data); // Data should be identical
// Check underlying data (column-major) after swap(0, 2)
assert_eq!(matrix.data(), &[7, 8, 9, 4, 5, 6, 1, 2, 3]);
}
#[test]
#[should_panic(expected = "column index out of bounds")]
fn test_swap_columns_out_of_bounds() {
let mut matrix = create_test_matrix(); // 3x3
matrix.swap_columns(0, 3);
}
#[test]
fn test_delete_column() {
let mut matrix = create_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
matrix.delete_column(1); // Delete the second column (index 1)
// Should be:
// 1 5 7
// 2 6 8
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(1, 0)], 2);
assert_eq!(matrix[(0, 1)], 5);
assert_eq!(matrix[(1, 1)], 6);
assert_eq!(matrix[(0, 2)], 7);
assert_eq!(matrix[(1, 2)], 8);
// Check underlying data (column-major)
assert_eq!(matrix.data(), &[1, 2, 5, 6, 7, 8]);
// Delete the first column
matrix.delete_column(0);
// Should be:
// 5 7
// 6 8
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 2);
assert_eq!(matrix.data(), &[5, 6, 7, 8]);
// Delete the last column
matrix.delete_column(1);
// Should be:
// 5
// 6
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 1);
assert_eq!(matrix.data(), &[5, 6]);
// Delete the only column
matrix.delete_column(0);
// Should be empty
assert_eq!(matrix.rows(), 2); // Rows stay the same
assert_eq!(matrix.cols(), 0); // Cols becomes 0
assert_eq!(matrix.data(), &[]);
}
#[test]
#[should_panic(expected = "column index out of bounds")]
fn test_delete_column_out_of_bounds() {
let mut matrix = create_test_matrix_2x4(); // 2x4
matrix.delete_column(4);
}
#[test]
fn test_delete_row() {
let mut matrix = create_test_matrix(); // 3x3
// 1 4 7
// 2 5 8
// 3 6 9
matrix.delete_row(1); // Delete the second row (index 1)
// Should be:
// 1 4 7
// 3 6 9
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(1, 0)], 3);
assert_eq!(matrix[(0, 1)], 4);
assert_eq!(matrix[(1, 1)], 6);
assert_eq!(matrix[(0, 2)], 7);
assert_eq!(matrix[(1, 2)], 9);
// Check underlying data (column-major)
// Original: [1, 2, 3, 4, 5, 6, 7, 8, 9]
// Delete row 1: [1, 3, 4, 6, 7, 9]
assert_eq!(matrix.data(), &[1, 3, 4, 6, 7, 9]);
// Delete the first row
matrix.delete_row(0);
// Should be:
// 3 6 9
assert_eq!(matrix.rows(), 1);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix.data(), &[3, 6, 9]);
// Delete the last (and only) row
matrix.delete_row(0);
// Should be empty
assert_eq!(matrix.rows(), 0); // Rows becomes 0
assert_eq!(matrix.cols(), 3); // Cols stay the same
assert_eq!(matrix.data(), &[]);
}
#[test]
#[should_panic(expected = "row index out of bounds")]
fn test_delete_row_out_of_bounds() {
let mut matrix = create_test_matrix(); // 3x3
matrix.delete_row(3);
}
#[test]
fn test_add_column() {
let mut matrix = create_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
let new_col = vec![9, 10];
matrix.add_column(2, new_col); // Add at index 2
// Should be:
// 1 3 9 5 7
// 2 4 10 6 8
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 5);
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(1, 0)], 2);
assert_eq!(matrix[(0, 1)], 3);
assert_eq!(matrix[(1, 1)], 4);
assert_eq!(matrix[(0, 2)], 9);
assert_eq!(matrix[(1, 2)], 10);
assert_eq!(matrix[(0, 3)], 5); // Shifted
assert_eq!(matrix[(1, 3)], 6);
assert_eq!(matrix[(0, 4)], 7); // Shifted
assert_eq!(matrix[(1, 4)], 8);
// Check underlying data (column-major)
// Original: [1, 2, 3, 4, 5, 6, 7, 8]
// Add [9, 10] at index 2: [1, 2, 3, 4, 9, 10, 5, 6, 7, 8]
assert_eq!(matrix.data(), &[1, 2, 3, 4, 9, 10, 5, 6, 7, 8]);
// Add a column at the beginning
let new_col_start = vec![11, 12];
matrix.add_column(0, new_col_start);
// Should be:
// 11 1 3 9 5 7
// 12 2 4 10 6 8
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 6);
assert_eq!(matrix[(0, 0)], 11);
assert_eq!(matrix[(1, 0)], 12);
assert_eq!(matrix.data(), &[11, 12, 1, 2, 3, 4, 9, 10, 5, 6, 7, 8]);
// Add a column at the end
let new_col_end = vec![13, 14];
matrix.add_column(6, new_col_end);
// Should be:
// 11 1 3 9 5 7 13
// 12 2 4 10 6 8 14
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 7);
assert_eq!(matrix[(0, 6)], 13);
assert_eq!(matrix[(1, 6)], 14);
assert_eq!(
&a * &b,
Matrix::from_cols(vec![vec![10, 40], vec![90, 160]])
matrix.data(),
&[11, 12, 1, 2, 3, 4, 9, 10, 5, 6, 7, 8, 13, 14]
);
assert_eq!(&b / &a, Matrix::from_cols(vec![vec![10, 10], vec![10, 10]]));
}
// f64specific helpers
#[test]
fn reductions_with_nan() {
let m = Matrix::from_cols(vec![vec![1.0, NAN, 3.0], vec![4.0, 5.0, NAN]]);
// Matrix: 3 rows × 2 columns
assert_vec_f64_eq(&m.sum_vertical(), &[4.0, 9.0]);
assert_vec_f64_eq(&m.sum_horizontal(), &[5.0, 5.0, 3.0]);
assert_vec_f64_eq(&m.prod_vertical(), &[3.0, 20.0]);
assert_vec_f64_eq(&m.prod_horizontal(), &[4.0, 5.0, 3.0]);
assert_eq!(m.count_nan_vertical(), vec![1, 1]);
assert_eq!(m.count_nan_horizontal(), vec![0, 1, 1]);
}
#[test]
fn apply_axis_dispatch() {
let m = Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
let col_sums = m.apply_axis(Axis::Col, |col| col.iter().sum::<f64>());
let row_sums = m.apply_axis(Axis::Row, |row| row.iter().sum::<f64>());
assert_vec_f64_eq(&col_sums, &[3.0, 7.0]);
assert_vec_f64_eq(&row_sums, &[4.0, 6.0]);
#[should_panic(expected = "column index out of bounds")]
fn test_add_column_out_of_bounds() {
let mut matrix = create_test_matrix_2x4(); // 2x4
let new_col = vec![9, 10];
matrix.add_column(5, new_col); // Index 5 is out of bounds for 4 columns
}
#[test]
fn is_nan_mask() {
let m = Matrix::from_cols(vec![vec![1.0, NAN]]);
let expected = Matrix::from_cols(vec![vec![false, true]]);
assert_eq!(m.is_nan(), expected);
#[should_panic(expected = "column length mismatch")]
fn test_add_column_length_mismatch() {
let mut matrix = create_test_matrix_2x4(); // 2x4 (2 rows)
let new_col = vec![9, 10, 11]; // Wrong length
matrix.add_column(0, new_col);
}
#[test]
fn test_add_row() {
let mut matrix = create_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
let new_row = vec![9, 10, 11, 12];
matrix.add_row(1, new_row); // Add at index 1
// Should be:
// 1 3 5 7
// 9 10 11 12
// 2 4 6 8
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 4);
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(0, 1)], 3);
assert_eq!(matrix[(0, 2)], 5);
assert_eq!(matrix[(0, 3)], 7);
assert_eq!(matrix[(1, 0)], 9);
assert_eq!(matrix[(1, 1)], 10);
assert_eq!(matrix[(1, 2)], 11);
assert_eq!(matrix[(1, 3)], 12);
assert_eq!(matrix[(2, 0)], 2);
assert_eq!(matrix[(2, 1)], 4);
assert_eq!(matrix[(2, 2)], 6);
assert_eq!(matrix[(2, 3)], 8);
// Check underlying data (column-major)
// Original: [1, 2, 3, 4, 5, 6, 7, 8] (rows 0, 1)
// Add [9, 10, 11, 12] at index 1 (new row will be index 1, original row 1 becomes index 2)
// Col 0: [1, 9, 2]
// Col 1: [3, 10, 4]
// Col 2: [5, 11, 6]
// Col 3: [7, 12, 8]
// Data: [1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]
assert_eq!(matrix.data(), &[1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]);
// Add a row at the beginning
let new_row_start = vec![13, 14, 15, 16];
matrix.add_row(0, new_row_start);
// Should be:
// 13 14 15 16
// 1 3 5 7
// 9 10 11 12
// 2 4 6 8
assert_eq!(matrix.rows(), 4);
assert_eq!(matrix.cols(), 4);
assert_eq!(matrix[(0, 0)], 13);
assert_eq!(matrix[(0, 3)], 16);
// Check some existing elements to ensure they shifted correctly
assert_eq!(matrix[(1, 0)], 1);
assert_eq!(matrix[(2, 1)], 10);
assert_eq!(matrix[(3, 3)], 8);
// Add a row at the end
let new_row_end = vec![17, 18, 19, 20];
matrix.add_row(4, new_row_end);
// Should be:
// 13 14 15 16
// 1 3 5 7
// 9 10 11 12
// 2 4 6 8
// 17 18 19 20
assert_eq!(matrix.rows(), 5);
assert_eq!(matrix.cols(), 4);
assert_eq!(matrix[(4, 0)], 17);
assert_eq!(matrix[(4, 3)], 20);
}
#[test]
#[should_panic(expected = "row index out of bounds")]
fn test_add_row_out_of_bounds() {
let mut matrix = create_test_matrix_2x4(); // 2x4
let new_row = vec![9, 10, 11, 12];
matrix.add_row(3, new_row); // Index 3 is out of bounds for 2 rows
}
#[test]
#[should_panic(expected = "row length mismatch")]
fn test_add_row_length_mismatch() {
let mut matrix = create_test_matrix_2x4(); // 2x4 (4 cols)
let new_row = vec![9, 10, 11]; // Wrong length
matrix.add_row(0, new_row);
}
#[test]
fn test_elementwise_add() {
let matrix1 = create_test_matrix(); // 3x3
let matrix2 = Matrix::from_vec(vec![9, 8, 7, 6, 5, 4, 3, 2, 1], 3, 3); // 3x3
let result = &matrix1 + &matrix2;
assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3);
// Expected:
// 1+9 4+6 7+3 => 10 10 10
// 2+8 5+5 8+2 => 10 10 10
// 3+7 6+4 9+1 => 10 10 10
// Column-major data: [10, 10, 10, 10, 10, 10, 10, 10, 10]
assert_eq!(result.data(), &[10, 10, 10, 10, 10, 10, 10, 10, 10]);
assert_eq!(result[(0, 0)], 10);
assert_eq!(result[(1, 1)], 10);
assert_eq!(result[(2, 2)], 10);
}
#[test]
fn test_elementwise_sub() {
let matrix1 = create_test_matrix(); // 3x3
let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], 3, 3); // 3x3
let result = &matrix1 - &matrix2;
assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3);
// Expected:
// 1-1 4-2 7-3 => 0 2 4
// 2-1 5-2 8-3 => 1 3 5
// 3-1 6-2 9-3 => 2 4 6
// Column-major data: [0, 1, 2, 2, 3, 4, 4, 5, 6]
assert_eq!(result.data(), &[0, 1, 2, 2, 3, 4, 4, 5, 6]);
assert_eq!(result[(0, 0)], 0);
assert_eq!(result[(1, 1)], 3);
assert_eq!(result[(2, 2)], 6);
}
#[test]
fn test_elementwise_mul() {
let matrix1 = create_test_matrix(); // 3x3
let matrix2 = Matrix::from_vec(vec![1, 2, 3, 1, 2, 3, 1, 2, 3], 3, 3); // 3x3
let result = &matrix1 * &matrix2;
assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3);
// Expected:
// 1*1 4*1 7*1 => 1 4 7
// 2*2 5*2 8*2 => 4 10 16
// 3*3 6*3 9*3 => 9 18 27
// Column-major data: [1, 4, 9, 4, 10, 18, 7, 16, 27]
assert_eq!(result.data(), &[1, 4, 9, 4, 10, 18, 7, 16, 27]);
assert_eq!(result[(0, 0)], 1);
assert_eq!(result[(1, 1)], 10);
assert_eq!(result[(2, 2)], 27);
}
#[test]
fn test_elementwise_div() {
let matrix1 = create_test_matrix(); // 3x3
let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 7, 8, 9], 3, 3); // 3x3
let result = &matrix1 / &matrix2; // Integer division
assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3);
// Expected:
// 1/1 4/2 7/7 => 1 2 1
// 2/1 5/2 8/8 => 2 2 1 (integer division)
// 3/1 6/2 9/9 => 3 3 1
// Column-major data: [1, 2, 3, 2, 2, 3, 1, 1, 1]
assert_eq!(result.data(), &[1, 2, 3, 2, 2, 3, 1, 1, 1]);
assert_eq!(result[(0, 0)], 1);
assert_eq!(result[(1, 1)], 2);
assert_eq!(result[(2, 2)], 1);
}
#[test]
#[should_panic(expected = "row count mismatch")]
fn test_elementwise_op_row_mismatch() {
let matrix1 = create_test_matrix(); // 3x3
let matrix2 = create_test_matrix_2x4(); // 2x4
let _ = &matrix1 + &matrix2; // Should panic
}
#[test]
#[should_panic(expected = "row count mismatch")]
fn test_elementwise_op_col_mismatch() {
let matrix1 = create_test_matrix(); // 3x3
let matrix2 = create_test_matrix_2x4(); // 2x4
let _ = &matrix1 * &matrix2; // Should panic
}
#[test]
fn test_bitwise_and() {
let data1 = vec![true, false, true, false, true, false]; // 2x3
let data2 = vec![true, true, false, false, true, true]; // 2x3
let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
// Expected column-major results:
// T & T = T
// F & T = F
// T & F = F
// F & F = F
// T & T = T
// F & T = F
// Data: [T, F, F, F, T, F]
let expected_data = vec![true, false, false, false, true, false];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
let result = &matrix1 & &matrix2;
assert_eq!(result, expected_matrix);
}
#[test]
fn test_bitwise_or() {
let data1 = vec![true, false, true, false, true, false]; // 2x3
let data2 = vec![true, true, false, false, true, true]; // 2x3
let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
// Expected column-major results:
// T | T = T
// F | T = T
// T | F = T
// F | F = F
// T | T = T
// F | T = T
// Data: [T, T, T, F, T, T]
let expected_data = vec![true, true, true, false, true, true];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
let result = &matrix1 | &matrix2;
assert_eq!(result, expected_matrix);
}
#[test]
fn test_bitwise_xor() {
let data1 = vec![true, false, true, false, true, false]; // 2x3
let data2 = vec![true, true, false, false, true, true]; // 2x3
let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
// Expected column-major results:
// T ^ T = F
// F ^ T = T
// T ^ F = T
// F ^ F = F
// T ^ T = F
// F ^ T = T
// Data: [F, T, T, F, F, T]
let expected_data = vec![false, true, true, false, false, true];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
let result = &matrix1 ^ &matrix2;
assert_eq!(result, expected_matrix);
}
#[test]
fn test_bitwise_not() {
let data = vec![true, false, true, false, true, false]; // 2x3
let matrix = BoolMatrix::from_vec(data, 2, 3);
// Expected column-major results:
// !T = F
// !F = T
// !T = F
// !F = T
// !T = F
// !F = T
// Data: [F, T, F, T, F, T]
let expected_data = vec![false, true, false, true, false, true];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
let result = !matrix; // Not consumes the matrix
assert_eq!(result, expected_matrix);
}
#[test]
#[should_panic(expected = "col count mismatch")]
fn test_bitwise_op_row_mismatch() {
let data1 = vec![true, false, true, false]; // 2x2
let data2 = vec![true, true, false, false, true, true]; // 2x3
let matrix1 = BoolMatrix::from_vec(data1, 2, 2);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
let _ = &matrix1 & &matrix2; // Should panic
}
#[test]
#[should_panic(expected = "col count mismatch")]
fn test_bitwise_op_col_mismatch() {
let data1 = vec![true, false, true, false]; // 2x2
let data2 = vec![true, true, false, false, true, true]; // 2x3
let matrix1 = BoolMatrix::from_vec(data1, 2, 2);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
let _ = &matrix1 | &matrix2; // Should panic
}
// Test with String type (requires Clone, PartialEq)
#[test]
fn test_string_matrix() {
let data = vec![
"a".to_string(),
"b".to_string(),
"c".to_string(),
"d".to_string(),
];
let matrix = StringMatrix::from_vec(data.clone(), 2, 2); // 2x2
assert_eq!(matrix[(0, 0)], "a".to_string());
assert_eq!(matrix[(1, 0)], "b".to_string());
assert_eq!(matrix[(0, 1)], "c".to_string());
assert_eq!(matrix[(1, 1)], "d".to_string());
// Test modification
let mut matrix = matrix;
matrix[(0, 0)] = "hello".to_string();
assert_eq!(matrix[(0, 0)], "hello".to_string());
// Test add_column (requires Clone)
let new_col = vec!["e".to_string(), "f".to_string()];
matrix.add_column(1, new_col); // Add at index 1
// Should be:
// hello c d
// b e f
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], "hello".to_string());
assert_eq!(matrix[(1, 0)], "b".to_string());
assert_eq!(matrix[(0, 1)], "e".to_string()); // New col
assert_eq!(matrix[(1, 1)], "f".to_string()); // New col
assert_eq!(matrix[(0, 2)], "c".to_string()); // Shifted
assert_eq!(matrix[(1, 2)], "d".to_string()); // Shifted
// Test add_row (requires Clone)
let new_row = vec!["g".to_string(), "h".to_string(), "i".to_string()];
matrix.add_row(0, new_row); // Add at index 0
// Should be:
// g h i
// hello e c
// b f d
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], "g".to_string());
assert_eq!(matrix[(0, 1)], "h".to_string());
assert_eq!(matrix[(0, 2)], "i".to_string());
assert_eq!(matrix[(1, 0)], "hello".to_string()); // Shifted
assert_eq!(matrix[(2, 2)], "d".to_string()); // Shifted
}
#[test]
fn test_float_matrix_ops() {
let data1 = vec![1.0, 2.0, 3.0, 4.0]; // 2x2
let data2 = vec![0.5, 1.5, 2.5, 3.5]; // 2x2
let matrix1 = FloatMatrix::from_vec(data1, 2, 2);
let matrix2 = FloatMatrix::from_vec(data2, 2, 2);
let sum = &matrix1 + &matrix2;
let diff = &matrix1 - &matrix2;
let prod = &matrix1 * &matrix2;
let div = &matrix1 / &matrix2;
// Check sums (col-major): [1.5, 3.5, 5.5, 7.5]
assert_eq!(sum.data(), &[1.5, 3.5, 5.5, 7.5]);
// Check diffs (col-major): [0.5, 0.5, 0.5, 0.5]
assert_eq!(diff.data(), &[0.5, 0.5, 0.5, 0.5]);
// Check prods (col-major): [0.5, 3.0, 7.5, 14.0]
assert_eq!(prod.data(), &[0.5, 3.0, 7.5, 14.0]);
// Check divs (col-major): [2.0, 1.333..., 1.2, 1.14...]
// Using element access for more specific checks on floating point results
assert_eq!(div.rows(), 2);
assert_eq!(div.cols(), 2);
assert!((div[(0, 0)] - 1.0 / 0.5).abs() < 1e-9); // 2.0
assert!((div[(1, 0)] - 2.0 / 1.5).abs() < 1e-9); // 1.333...
assert!((div[(0, 1)] - 3.0 / 2.5).abs() < 1e-9); // 1.2
assert!((div[(1, 1)] - 4.0 / 3.5).abs() < 1e-9); // 1.14...
}
// Axis enum doesn't have logic, no tests needed directly, but its presence is verified by compilation.
}