Refactor assertions for clarity in Matrix methods and update test cases for consistency

This commit is contained in:
Palash Tyagi 2025-07-05 21:07:32 +01:00
parent ffe635f1c4
commit 1ebc3546d2

View File

@ -168,7 +168,12 @@ impl<T: Clone> Matrix<T> {
/// Deletes a column from the matrix. Panics on out-of-bounds. /// Deletes a column from the matrix. Panics on out-of-bounds.
/// This is O(N) where N is the number of elements. /// This is O(N) where N is the number of elements.
pub fn delete_column(&mut self, col: usize) { pub fn delete_column(&mut self, col: usize) {
assert!(col < self.cols, "column index {} out of bounds for {} columns", col, self.cols); assert!(
col < self.cols,
"column index {} out of bounds for {} columns",
col,
self.cols
);
let start = col * self.rows; let start = col * self.rows;
self.data.drain(start..start + self.rows); // Efficient removal self.data.drain(start..start + self.rows); // Efficient removal
self.cols -= 1; self.cols -= 1;
@ -254,8 +259,19 @@ impl<T: Clone> Matrix<T> {
/// Adds a row to the matrix at the specified index. Panics if index > rows or length mismatch. /// Adds a row to the matrix at the specified index. Panics if index > rows or length mismatch.
/// This is O(N) where N is the number of elements, as it rebuilds the data vec. /// This is O(N) where N is the number of elements, as it rebuilds the data vec.
pub fn add_row(&mut self, index: usize, row: Vec<T>) { pub fn add_row(&mut self, index: usize, row: Vec<T>) {
assert!(index <= self.rows, "add_row index {} out of bounds for {} rows", index, self.rows); assert!(
assert_eq!(row.len(), self.cols, "row length mismatch: expected {} (cols), got {}", self.cols, row.len()); index <= self.rows,
"add_row index {} out of bounds for {} rows",
index,
self.rows
);
assert_eq!(
row.len(),
self.cols,
"row length mismatch: expected {} (cols), got {}",
self.cols,
row.len()
);
if self.cols == 0 && self.rows == 0 { if self.cols == 0 && self.rows == 0 {
// Special case: adding first row to empty matrix // Special case: adding first row to empty matrix
@ -640,7 +656,6 @@ pub type BoolMatrix = Matrix<bool>;
pub type IntMatrix = Matrix<i32>; pub type IntMatrix = Matrix<i32>;
pub type StringMatrix = Matrix<String>; pub type StringMatrix = Matrix<String>;
// --- Unit Tests --- // --- Unit Tests ---
#[cfg(test)] #[cfg(test)]
@ -1049,20 +1064,20 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "index out of bounds")] #[should_panic(expected = "index out of bounds")]
fn test_index_out_of_bounds_row_alt() { fn test_index_out_of_bounds_row_alt() {
let matrix = static_test_matrix(); // 3x3 let matrix = static_test_matrix();
let _ = matrix[(3, 0)]; let _ = matrix[(3, 0)];
} }
#[test] #[test]
#[should_panic(expected = "index out of bounds")] #[should_panic(expected = "index out of bounds")]
fn test_index_out_of_bounds_col_alt() { fn test_index_out_of_bounds_col_alt() {
let matrix = static_test_matrix(); // 3x3 let matrix = static_test_matrix();
let _ = matrix[(0, 3)]; let _ = matrix[(0, 3)];
} }
#[test] #[test]
fn test_index_mut_and_get_mut() { fn test_index_mut_and_get_mut() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
matrix[(0, 0)] = 10; matrix[(0, 0)] = 10;
matrix[(1, 1)] = 20; matrix[(1, 1)] = 20;
@ -1078,33 +1093,26 @@ mod tests {
assert_eq!(matrix[(0, 1)], 15); assert_eq!(matrix[(0, 1)], 15);
assert_eq!(matrix[(2, 1)], 25); assert_eq!(matrix[(2, 1)], 25);
// Check underlying data consistency (column-major)
// Should be:
// 10 15 7
// 2 20 8
// 3 25 30
assert_eq!(matrix.data(), &[10, 2, 3, 15, 20, 25, 7, 8, 30]); assert_eq!(matrix.data(), &[10, 2, 3, 15, 20, 25, 7, 8, 30]);
} }
#[test] #[test]
#[should_panic(expected = "index out of bounds")] #[should_panic(expected = "index out of bounds")]
fn test_index_mut_out_of_bounds_row() { fn test_index_mut_out_of_bounds_row() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
matrix[(3, 0)] = 99; matrix[(3, 0)] = 99;
} }
#[test] #[test]
#[should_panic(expected = "index out of bounds")] #[should_panic(expected = "index out of bounds")]
fn test_index_mut_out_of_bounds_col() { fn test_index_mut_out_of_bounds_col() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
matrix[(0, 3)] = 99; matrix[(0, 3)] = 99;
} }
#[test] #[test]
fn test_column() { fn test_column() {
let matrix = static_test_matrix_2x4(); // 2x4 let matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
assert_eq!(matrix.column(0), &[1, 2]); assert_eq!(matrix.column(0), &[1, 2]);
assert_eq!(matrix.column(1), &[3, 4]); assert_eq!(matrix.column(1), &[3, 4]);
@ -1115,15 +1123,13 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")] #[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_column_out_of_bounds() { fn test_column_out_of_bounds() {
let matrix = static_test_matrix_2x4(); // 2x4 let matrix = static_test_matrix_2x4();
matrix.column(4); matrix.column(4);
} }
#[test] #[test]
fn test_column_mut() { fn test_column_mut() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
let col1_mut = matrix.column_mut(1); let col1_mut = matrix.column_mut(1);
col1_mut[0] = 30; col1_mut[0] = 30;
@ -1138,25 +1144,19 @@ mod tests {
assert_eq!(matrix[(0, 3)], 70); assert_eq!(matrix[(0, 3)], 70);
assert_eq!(matrix[(1, 3)], 8); // Unchanged assert_eq!(matrix[(1, 3)], 8); // Unchanged
// Check underlying data (column-major)
// Should be:
// 1 30 5 70
// 2 40 6 8
assert_eq!(matrix.data(), &[1, 2, 30, 40, 5, 6, 70, 8]); assert_eq!(matrix.data(), &[1, 2, 30, 40, 5, 6, 70, 8]);
} }
#[test] #[test]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")] #[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_column_mut_out_of_bounds() { fn test_column_mut_out_of_bounds() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
matrix.column_mut(4); matrix.column_mut(4);
} }
#[test] #[test]
fn test_iter_columns() { fn test_iter_columns() {
let matrix = static_test_matrix_2x4(); // 2x4 let matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
let cols: Vec<&[i32]> = matrix.iter_columns().collect(); let cols: Vec<&[i32]> = matrix.iter_columns().collect();
assert_eq!(cols.len(), 4); assert_eq!(cols.len(), 4);
@ -1168,9 +1168,7 @@ mod tests {
#[test] #[test]
fn test_iter_rows() { fn test_iter_rows() {
let matrix = static_test_matrix_2x4(); // 2x4 let matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
let rows: Vec<Vec<i32>> = matrix let rows: Vec<Vec<i32>> = matrix
.iter_rows() .iter_rows()
@ -1184,10 +1182,7 @@ mod tests {
// test data_mut // test data_mut
#[test] #[test]
fn test_data_mut() { fn test_data_mut() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
// 1 4 7
// 2 5 8
// 3 6 9
let data_mut = matrix.data_mut(); let data_mut = matrix.data_mut();
data_mut[0] = 10; data_mut[0] = 10;
@ -1199,10 +1194,7 @@ mod tests {
#[test] #[test]
fn test_matrix_row_get_and_iter() { fn test_matrix_row_get_and_iter() {
let matrix = static_test_matrix_2x4(); // 2x4 let matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
let row0 = matrix.iter_rows().next().unwrap(); let row0 = matrix.iter_rows().next().unwrap();
assert_eq!(*row0.get(0), 1); assert_eq!(*row0.get(0), 1);
assert_eq!(*row0.get(1), 3); assert_eq!(*row0.get(1), 3);
@ -1219,17 +1211,9 @@ mod tests {
#[test] #[test]
fn test_swap_columns() { fn test_swap_columns() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
// 1 4 7
// 2 5 8
// 3 6 9
matrix.swap_columns(0, 2); // Swap first and last matrix.swap_columns(0, 2); // swap first and last
// Should be:
// 7 4 1
// 8 5 2
// 9 6 3
assert_eq!(matrix.rows(), 3); assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3); assert_eq!(matrix.cols(), 3);
@ -1243,7 +1227,7 @@ mod tests {
assert_eq!(matrix[(1, 2)], 2); assert_eq!(matrix[(1, 2)], 2);
assert_eq!(matrix[(2, 2)], 3); assert_eq!(matrix[(2, 2)], 3);
// Swap the same column (should do nothing) // swap the same column (should do nothing)
let original_data = matrix.data().to_vec(); let original_data = matrix.data().to_vec();
matrix.swap_columns(1, 1); matrix.swap_columns(1, 1);
assert_eq!(matrix.data(), &original_data); // Data should be identical assert_eq!(matrix.data(), &original_data); // Data should be identical
@ -1255,22 +1239,15 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "column index c2=3 out of bounds for 3 columns")] #[should_panic(expected = "column index c2=3 out of bounds for 3 columns")]
fn test_swap_columns_out_of_bounds() { fn test_swap_columns_out_of_bounds() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
matrix.swap_columns(0, 3); matrix.swap_columns(0, 3);
} }
#[test] #[test]
fn test_delete_column() { fn test_delete_column() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
matrix.delete_column(1); // Delete the second column (index 1) matrix.delete_column(1); // Delete the second column (index 1)
// Should be:
// 1 5 7
// 2 6 8
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3); assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], 1); assert_eq!(matrix[(0, 0)], 1);
@ -1280,23 +1257,17 @@ mod tests {
assert_eq!(matrix[(0, 2)], 7); assert_eq!(matrix[(0, 2)], 7);
assert_eq!(matrix[(1, 2)], 8); assert_eq!(matrix[(1, 2)], 8);
// Check underlying data (column-major) // check underlying data
assert_eq!(matrix.data(), &[1, 2, 5, 6, 7, 8]); assert_eq!(matrix.data(), &[1, 2, 5, 6, 7, 8]);
// Delete the first column // Delete the first column
matrix.delete_column(0); matrix.delete_column(0);
// Should be:
// 5 7
// 6 8
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 2); assert_eq!(matrix.cols(), 2);
assert_eq!(matrix.data(), &[5, 6, 7, 8]); assert_eq!(matrix.data(), &[5, 6, 7, 8]);
// Delete the last column // Delete the last column
matrix.delete_column(1); matrix.delete_column(1);
// Should be:
// 5
// 6
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 1); assert_eq!(matrix.cols(), 1);
assert_eq!(matrix.data(), &[5, 6]); assert_eq!(matrix.data(), &[5, 6]);
@ -1312,22 +1283,15 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")] #[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_delete_column_out_of_bounds() { fn test_delete_column_out_of_bounds() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
matrix.delete_column(4); matrix.delete_column(4);
} }
#[test] #[test]
fn test_delete_row() { fn test_delete_row() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
// 1 4 7
// 2 5 8
// 3 6 9
matrix.delete_row(1); // Delete the second row (index 1) matrix.delete_row(1); // Delete the second row
// Should be:
// 1 4 7
// 3 6 9
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3); assert_eq!(matrix.cols(), 3);
@ -1338,9 +1302,7 @@ mod tests {
assert_eq!(matrix[(0, 2)], 7); assert_eq!(matrix[(0, 2)], 7);
assert_eq!(matrix[(1, 2)], 9); assert_eq!(matrix[(1, 2)], 9);
// Check underlying data (column-major) // check underlying data (column-major)
// Original: [1, 2, 3, 4, 5, 6, 7, 8, 9]
// Delete row 1: [1, 3, 4, 6, 7, 9]
assert_eq!(matrix.data(), &[1, 3, 4, 6, 7, 9]); assert_eq!(matrix.data(), &[1, 3, 4, 6, 7, 9]);
// Delete the first row // Delete the first row
@ -1362,22 +1324,15 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "row index 3 out of bounds for 3 rows")] #[should_panic(expected = "row index 3 out of bounds for 3 rows")]
fn test_delete_row_out_of_bounds() { fn test_delete_row_out_of_bounds() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
matrix.delete_row(3); matrix.delete_row(3);
} }
#[test] #[test]
fn test_add_column() { fn test_add_column() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
let new_col = vec![9, 10]; let new_col = vec![9, 10];
matrix.add_column(2, new_col); // Add at index 2 matrix.add_column(2, new_col);
// Should be:
// 1 3 9 5 7
// 2 4 10 6 8
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 5); assert_eq!(matrix.cols(), 5);
@ -1392,17 +1347,12 @@ mod tests {
assert_eq!(matrix[(0, 4)], 7); // Shifted assert_eq!(matrix[(0, 4)], 7); // Shifted
assert_eq!(matrix[(1, 4)], 8); assert_eq!(matrix[(1, 4)], 8);
// Check underlying data (column-major) // Check underlying data
// Original: [1, 2, 3, 4, 5, 6, 7, 8]
// Add [9, 10] at index 2: [1, 2, 3, 4, 9, 10, 5, 6, 7, 8]
assert_eq!(matrix.data(), &[1, 2, 3, 4, 9, 10, 5, 6, 7, 8]); assert_eq!(matrix.data(), &[1, 2, 3, 4, 9, 10, 5, 6, 7, 8]);
// Add a column at the beginning // Add a column at the beginning
let new_col_start = vec![11, 12]; let new_col_start = vec![11, 12];
matrix.add_column(0, new_col_start); matrix.add_column(0, new_col_start);
// Should be:
// 11 1 3 9 5 7
// 12 2 4 10 6 8
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 6); assert_eq!(matrix.cols(), 6);
assert_eq!(matrix[(0, 0)], 11); assert_eq!(matrix[(0, 0)], 11);
@ -1412,9 +1362,6 @@ mod tests {
// Add a column at the end // Add a column at the end
let new_col_end = vec![13, 14]; let new_col_end = vec![13, 14];
matrix.add_column(6, new_col_end); matrix.add_column(6, new_col_end);
// Should be:
// 11 1 3 9 5 7 13
// 12 2 4 10 6 8 14
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 7); assert_eq!(matrix.cols(), 7);
assert_eq!(matrix[(0, 6)], 13); assert_eq!(matrix[(0, 6)], 13);
@ -1428,7 +1375,7 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "add_column index 5 out of bounds for 4 columns")] #[should_panic(expected = "add_column index 5 out of bounds for 4 columns")]
fn test_add_column_out_of_bounds() { fn test_add_column_out_of_bounds() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
let new_col = vec![9, 10]; let new_col = vec![9, 10];
matrix.add_column(5, new_col); // Index 5 is out of bounds for 4 columns matrix.add_column(5, new_col); // Index 5 is out of bounds for 4 columns
} }
@ -1436,24 +1383,16 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "column length mismatch")] #[should_panic(expected = "column length mismatch")]
fn test_add_column_length_mismatch() { fn test_add_column_length_mismatch() {
let mut matrix = static_test_matrix_2x4(); // 2x4 (2 rows) let mut matrix = static_test_matrix_2x4();
let new_col = vec![9, 10, 11]; // Wrong length let new_col = vec![9, 10, 11]; // Wrong length
matrix.add_column(0, new_col); matrix.add_column(0, new_col);
} }
#[test] #[test]
fn test_add_row() { fn test_add_row() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
let new_row = vec![9, 10, 11, 12]; let new_row = vec![9, 10, 11, 12];
matrix.add_row(1, new_row); // Add at index 1 matrix.add_row(1, new_row);
// Should be:
// 1 3 5 7
// 9 10 11 12
// 2 4 6 8
assert_eq!(matrix.rows(), 3); assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 4); assert_eq!(matrix.cols(), 4);
@ -1472,23 +1411,11 @@ mod tests {
assert_eq!(matrix[(2, 3)], 8); assert_eq!(matrix[(2, 3)], 8);
// Check underlying data (column-major) // Check underlying data (column-major)
// Original: [1, 2, 3, 4, 5, 6, 7, 8] (rows 0, 1)
// Add [9, 10, 11, 12] at index 1 (new row will be index 1, original row 1 becomes index 2)
// Col 0: [1, 9, 2]
// Col 1: [3, 10, 4]
// Col 2: [5, 11, 6]
// Col 3: [7, 12, 8]
// Data: [1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]
assert_eq!(matrix.data(), &[1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]); assert_eq!(matrix.data(), &[1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]);
// Add a row at the beginning // Add a row at the beginning
let new_row_start = vec![13, 14, 15, 16]; let new_row_start = vec![13, 14, 15, 16];
matrix.add_row(0, new_row_start); matrix.add_row(0, new_row_start);
// Should be:
// 13 14 15 16
// 1 3 5 7
// 9 10 11 12
// 2 4 6 8
assert_eq!(matrix.rows(), 4); assert_eq!(matrix.rows(), 4);
assert_eq!(matrix.cols(), 4); assert_eq!(matrix.cols(), 4);
assert_eq!(matrix[(0, 0)], 13); assert_eq!(matrix[(0, 0)], 13);
@ -1502,12 +1429,6 @@ mod tests {
// Add a row at the end // Add a row at the end
let new_row_end = vec![17, 18, 19, 20]; let new_row_end = vec![17, 18, 19, 20];
matrix.add_row(4, new_row_end); matrix.add_row(4, new_row_end);
// Should be:
// 13 14 15 16
// 1 3 5 7
// 9 10 11 12
// 2 4 6 8
// 17 18 19 20
assert_eq!(matrix.rows(), 5); assert_eq!(matrix.rows(), 5);
assert_eq!(matrix.cols(), 4); assert_eq!(matrix.cols(), 4);
assert_eq!(matrix[(4, 0)], 17); assert_eq!(matrix[(4, 0)], 17);
@ -1517,7 +1438,7 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "add_row index 3 out of bounds for 2 rows")] #[should_panic(expected = "add_row index 3 out of bounds for 2 rows")]
fn test_add_row_out_of_bounds() { fn test_add_row_out_of_bounds() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
let new_row = vec![9, 10, 11, 12]; let new_row = vec![9, 10, 11, 12];
matrix.add_row(3, new_row); // Index 3 is out of bounds for 2 rows matrix.add_row(3, new_row); // Index 3 is out of bounds for 2 rows
} }
@ -1525,26 +1446,22 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "row length mismatch")] #[should_panic(expected = "row length mismatch")]
fn test_add_row_length_mismatch() { fn test_add_row_length_mismatch() {
let mut matrix = static_test_matrix_2x4(); // 2x4 (4 cols) let mut matrix = static_test_matrix_2x4();
let new_row = vec![9, 10, 11]; // Wrong length let new_row = vec![9, 10, 11]; // Wrong length
matrix.add_row(0, new_row); matrix.add_row(0, new_row);
} }
#[test] #[test]
fn test_elementwise_add() { fn test_elementwise_add() {
let matrix1 = static_test_matrix(); // 3x3 let matrix1 = static_test_matrix();
let matrix2 = Matrix::from_vec(vec![9, 8, 7, 6, 5, 4, 3, 2, 1], 3, 3); // 3x3 let matrix2 = Matrix::from_vec(vec![9, 8, 7, 6, 5, 4, 3, 2, 1], 3, 3);
let result = &matrix1 + &matrix2; let result = &matrix1 + &matrix2;
assert_eq!(result.rows(), 3); assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3); assert_eq!(result.cols(), 3);
// Expected: // Expect all 10s
// 1+9 4+6 7+3 => 10 10 10
// 2+8 5+5 8+2 => 10 10 10
// 3+7 6+4 9+1 => 10 10 10
// Column-major data: [10, 10, 10, 10, 10, 10, 10, 10, 10]
assert_eq!(result.data(), &[10, 10, 10, 10, 10, 10, 10, 10, 10]); assert_eq!(result.data(), &[10, 10, 10, 10, 10, 10, 10, 10, 10]);
assert_eq!(result[(0, 0)], 10); assert_eq!(result[(0, 0)], 10);
assert_eq!(result[(1, 1)], 10); assert_eq!(result[(1, 1)], 10);
@ -1553,19 +1470,14 @@ mod tests {
#[test] #[test]
fn test_elementwise_sub() { fn test_elementwise_sub() {
let matrix1 = static_test_matrix(); // 3x3 let matrix1 = static_test_matrix();
let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], 3, 3); // 3x3 let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], 3, 3);
let result = &matrix1 - &matrix2; let result = &matrix1 - &matrix2;
assert_eq!(result.rows(), 3); assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3); assert_eq!(result.cols(), 3);
// Expected:
// 1-1 4-2 7-3 => 0 2 4
// 2-1 5-2 8-3 => 1 3 5
// 3-1 6-2 9-3 => 2 4 6
// Column-major data: [0, 1, 2, 2, 3, 4, 4, 5, 6]
assert_eq!(result.data(), &[0, 1, 2, 2, 3, 4, 4, 5, 6]); assert_eq!(result.data(), &[0, 1, 2, 2, 3, 4, 4, 5, 6]);
assert_eq!(result[(0, 0)], 0); assert_eq!(result[(0, 0)], 0);
assert_eq!(result[(1, 1)], 3); assert_eq!(result[(1, 1)], 3);
@ -1574,19 +1486,15 @@ mod tests {
#[test] #[test]
fn test_elementwise_mul() { fn test_elementwise_mul() {
let matrix1 = static_test_matrix(); // 3x3 let matrix1 = static_test_matrix();
let matrix2 = Matrix::from_vec(vec![1, 2, 3, 1, 2, 3, 1, 2, 3], 3, 3); // 3x3 let matrix2 = Matrix::from_vec(vec![1, 2, 3, 1, 2, 3, 1, 2, 3], 3, 3);
let result = &matrix1 * &matrix2; let result = &matrix1 * &matrix2;
assert_eq!(result.rows(), 3); assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3); assert_eq!(result.cols(), 3);
// Expected: // Expected
// 1*1 4*1 7*1 => 1 4 7
// 2*2 5*2 8*2 => 4 10 16
// 3*3 6*3 9*3 => 9 18 27
// Column-major data: [1, 4, 9, 4, 10, 18, 7, 16, 27]
assert_eq!(result.data(), &[1, 4, 9, 4, 10, 18, 7, 16, 27]); assert_eq!(result.data(), &[1, 4, 9, 4, 10, 18, 7, 16, 27]);
assert_eq!(result[(0, 0)], 1); assert_eq!(result[(0, 0)], 1);
assert_eq!(result[(1, 1)], 10); assert_eq!(result[(1, 1)], 10);
@ -1595,19 +1503,14 @@ mod tests {
#[test] #[test]
fn test_elementwise_div() { fn test_elementwise_div() {
let matrix1 = static_test_matrix(); // 3x3 let matrix1 = static_test_matrix();
let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 7, 8, 9], 3, 3); // 3x3 let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 7, 8, 9], 3, 3);
let result = &matrix1 / &matrix2; // Integer division let result = &matrix1 / &matrix2; // Integer division
assert_eq!(result.rows(), 3); assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3); assert_eq!(result.cols(), 3);
// Expected:
// 1/1 4/2 7/7 => 1 2 1
// 2/1 5/2 8/8 => 2 2 1 (integer division)
// 3/1 6/2 9/9 => 3 3 1
// Column-major data: [1, 2, 3, 2, 2, 3, 1, 1, 1]
assert_eq!(result.data(), &[1, 2, 3, 2, 2, 3, 1, 1, 1]); assert_eq!(result.data(), &[1, 2, 3, 2, 2, 3, 1, 1, 1]);
assert_eq!(result[(0, 0)], 1); assert_eq!(result[(0, 0)], 1);
assert_eq!(result[(1, 1)], 2); assert_eq!(result[(1, 1)], 2);
@ -1617,34 +1520,26 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 rows")] #[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 rows")]
fn test_elementwise_op_row_mismatch() { fn test_elementwise_op_row_mismatch() {
let matrix1 = static_test_matrix(); // 3x3 let matrix1 = static_test_matrix();
let matrix2 = static_test_matrix_2x4(); // 2x4 let matrix2 = static_test_matrix_2x4();
let _ = &matrix1 + &matrix2; // Should panic let _ = &matrix1 + &matrix2; // Should panic
} }
#[test] #[test]
#[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 ro")] #[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 ro")]
fn test_elementwise_op_col_mismatch() { fn test_elementwise_op_col_mismatch() {
let matrix1 = static_test_matrix(); // 3x3 let matrix1 = static_test_matrix();
let matrix2 = static_test_matrix_2x4(); // 2x4 let matrix2 = static_test_matrix_2x4();
let _ = &matrix1 * &matrix2; // Should panic let _ = &matrix1 * &matrix2; // Should panic
} }
#[test] #[test]
fn test_bitwise_and() { fn test_bitwise_and() {
let data1 = vec![true, false, true, false, true, false]; // 2x3 let data1 = vec![true, false, true, false, true, false];
let data2 = vec![true, true, false, false, true, true]; // 2x3 let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 3); let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3); let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
// Expected column-major results:
// T & T = T
// F & T = F
// T & F = F
// F & F = F
// T & T = T
// F & T = F
// Data: [T, F, F, F, T, F]
let expected_data = vec![true, false, false, false, true, false]; let expected_data = vec![true, false, false, false, true, false];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
@ -1654,19 +1549,11 @@ mod tests {
#[test] #[test]
fn test_bitwise_or() { fn test_bitwise_or() {
let data1 = vec![true, false, true, false, true, false]; // 2x3 let data1 = vec![true, false, true, false, true, false];
let data2 = vec![true, true, false, false, true, true]; // 2x3 let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 3); let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3); let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
// Expected column-major results:
// T | T = T
// F | T = T
// T | F = T
// F | F = F
// T | T = T
// F | T = T
// Data: [T, T, T, F, T, T]
let expected_data = vec![true, true, true, false, true, true]; let expected_data = vec![true, true, true, false, true, true];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
@ -1676,19 +1563,11 @@ mod tests {
#[test] #[test]
fn test_bitwise_xor() { fn test_bitwise_xor() {
let data1 = vec![true, false, true, false, true, false]; // 2x3 let data1 = vec![true, false, true, false, true, false];
let data2 = vec![true, true, false, false, true, true]; // 2x3 let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 3); let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3); let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
// Expected column-major results:
// T ^ T = F
// F ^ T = T
// T ^ F = T
// F ^ F = F
// T ^ T = F
// F ^ T = T
// Data: [F, T, T, F, F, T]
let expected_data = vec![false, true, true, false, false, true]; let expected_data = vec![false, true, true, false, false, true];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
@ -1698,17 +1577,9 @@ mod tests {
#[test] #[test]
fn test_bitwise_not() { fn test_bitwise_not() {
let data = vec![true, false, true, false, true, false]; // 2x3 let data = vec![true, false, true, false, true, false];
let matrix = BoolMatrix::from_vec(data, 2, 3); let matrix = BoolMatrix::from_vec(data, 2, 3);
// Expected column-major results:
// !T = F
// !F = T
// !T = F
// !F = T
// !T = F
// !F = T
// Data: [F, T, F, T, F, T]
let expected_data = vec![false, true, false, true, false, true]; let expected_data = vec![false, true, false, true, false, true];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
@ -1719,8 +1590,8 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")] #[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")]
fn test_bitwise_op_row_mismatch() { fn test_bitwise_op_row_mismatch() {
let data1 = vec![true, false, true, false]; // 2x2 let data1 = vec![true, false, true, false];
let data2 = vec![true, true, false, false, true, true]; // 2x3 let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 2); let matrix1 = BoolMatrix::from_vec(data1, 2, 2);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3); let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
let _ = &matrix1 & &matrix2; // Should panic let _ = &matrix1 & &matrix2; // Should panic
@ -1729,8 +1600,8 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")] #[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")]
fn test_bitwise_op_col_mismatch() { fn test_bitwise_op_col_mismatch() {
let data1 = vec![true, false, true, false]; // 2x2 let data1 = vec![true, false, true, false];
let data2 = vec![true, true, false, false, true, true]; // 2x3 let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 2); let matrix1 = BoolMatrix::from_vec(data1, 2, 2);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3); let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
let _ = &matrix1 | &matrix2; // Should panic let _ = &matrix1 | &matrix2; // Should panic
@ -1745,7 +1616,7 @@ mod tests {
"c".to_string(), "c".to_string(),
"d".to_string(), "d".to_string(),
]; ];
let matrix = StringMatrix::from_vec(data.clone(), 2, 2); // 2x2 let matrix = StringMatrix::from_vec(data.clone(), 2, 2);
assert_eq!(matrix[(0, 0)], "a".to_string()); assert_eq!(matrix[(0, 0)], "a".to_string());
assert_eq!(matrix[(1, 0)], "b".to_string()); assert_eq!(matrix[(1, 0)], "b".to_string());
@ -1757,12 +1628,10 @@ mod tests {
matrix[(0, 0)] = "hello".to_string(); matrix[(0, 0)] = "hello".to_string();
assert_eq!(matrix[(0, 0)], "hello".to_string()); assert_eq!(matrix[(0, 0)], "hello".to_string());
// Test add_column (requires Clone) // Test add_column
let new_col = vec!["e".to_string(), "f".to_string()]; let new_col = vec!["e".to_string(), "f".to_string()];
matrix.add_column(1, new_col); // Add at index 1 matrix.add_column(1, new_col);
// Should be:
// hello c d
// b e f
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3); assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], "hello".to_string()); assert_eq!(matrix[(0, 0)], "hello".to_string());
@ -1774,11 +1643,8 @@ mod tests {
// Test add_row (requires Clone) // Test add_row (requires Clone)
let new_row = vec!["g".to_string(), "h".to_string(), "i".to_string()]; let new_row = vec!["g".to_string(), "h".to_string(), "i".to_string()];
matrix.add_row(0, new_row); // Add at index 0 matrix.add_row(0, new_row);
// Should be:
// g h i
// hello e c
// b f d
assert_eq!(matrix.rows(), 3); assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3); assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], "g".to_string()); assert_eq!(matrix[(0, 0)], "g".to_string());
@ -1790,8 +1656,8 @@ mod tests {
#[test] #[test]
fn test_float_matrix_ops() { fn test_float_matrix_ops() {
let data1 = vec![1.0, 2.0, 3.0, 4.0]; // 2x2 let data1 = vec![1.0, 2.0, 3.0, 4.0];
let data2 = vec![0.5, 1.5, 2.5, 3.5]; // 2x2 let data2 = vec![0.5, 1.5, 2.5, 3.5];
let matrix1 = FloatMatrix::from_vec(data1, 2, 2); let matrix1 = FloatMatrix::from_vec(data1, 2, 2);
let matrix2 = FloatMatrix::from_vec(data2, 2, 2); let matrix2 = FloatMatrix::from_vec(data2, 2, 2);
@ -1800,17 +1666,13 @@ mod tests {
let prod = &matrix1 * &matrix2; let prod = &matrix1 * &matrix2;
let div = &matrix1 / &matrix2; let div = &matrix1 / &matrix2;
// Check sums (col-major): [1.5, 3.5, 5.5, 7.5]
assert_eq!(sum.data(), &[1.5, 3.5, 5.5, 7.5]); assert_eq!(sum.data(), &[1.5, 3.5, 5.5, 7.5]);
// Check diffs (col-major): [0.5, 0.5, 0.5, 0.5]
assert_eq!(diff.data(), &[0.5, 0.5, 0.5, 0.5]); assert_eq!(diff.data(), &[0.5, 0.5, 0.5, 0.5]);
// Check prods (col-major): [0.5, 3.0, 7.5, 14.0]
assert_eq!(prod.data(), &[0.5, 3.0, 7.5, 14.0]); assert_eq!(prod.data(), &[0.5, 3.0, 7.5, 14.0]);
// Check divs (col-major): [2.0, 1.333..., 1.2, 1.14...] // Check divs (col-major): [2.0, 1.333..., 1.2, 1.14...]
// Using element access for more specific checks on floating point results
assert_eq!(div.rows(), 2); assert_eq!(div.rows(), 2);
assert_eq!(div.cols(), 2); assert_eq!(div.cols(), 2);
assert!((div[(0, 0)] - 1.0 / 0.5).abs() < 1e-9); // 2.0 assert!((div[(0, 0)] - 1.0 / 0.5).abs() < 1e-9); // 2.0
@ -1835,20 +1697,18 @@ mod tests {
// Perform the swap directly on the matrix // Perform the swap directly on the matrix
matrix.swap_columns(0, 2); // Swap column 0 and column 2 matrix.swap_columns(0, 2); // Swap column 0 and column 2
// --- Assertions --- // Verify the dimensions are unchanged
// 1. Verify the dimensions are unchanged
assert_eq!(matrix.rows(), 3, "Matrix rows should remain unchanged"); assert_eq!(matrix.rows(), 3, "Matrix rows should remain unchanged");
assert_eq!(matrix.cols(), 3, "Matrix cols should remain unchanged"); assert_eq!(matrix.cols(), 3, "Matrix cols should remain unchanged");
// 2. Verify the column that was NOT swapped is unchanged // Verify the column that was NOT swapped is unchanged
assert_eq!( assert_eq!(
matrix.column(1), matrix.column(1),
initial_col1_data.as_slice(), // Comparing slice to slice initial_col1_data.as_slice(), // Comparing slice to slice
"Column 1 data should be unchanged" "Column 1 data should be unchanged"
); );
// 3. Verify the data swap occurred correctly using the COLUMN ACCESSOR // Verify the data swap occurred correctly using the COLUMN ACCESSOR
// The data originally at index 0 should now be at index 2 // The data originally at index 0 should now be at index 2
assert_eq!( assert_eq!(
matrix.column(2), matrix.column(2),
@ -1862,16 +1722,14 @@ mod tests {
"Column 0 should now contain the original data from column 2" "Column 0 should now contain the original data from column 2"
); );
// 4. (Optional but useful) Verify the underlying raw data vector // Verify the underlying raw data vector
// Original data: [1, 2, 3, 4, 5, 6, 7, 8, 9]
// Expected data after swapping col 0 and col 2: [7, 8, 9, 4, 5, 6, 1, 2, 3]
assert_eq!( assert_eq!(
matrix.data(), matrix.data(),
&[7, 8, 9, 4, 5, 6, 1, 2, 3], &[7, 8, 9, 4, 5, 6, 1, 2, 3],
"Underlying data vector is incorrect after swap" "Underlying data vector is incorrect after swap"
); );
// 5. Test swapping with self (should be a no-op) // Test swapping with self (should be a no-op)
let state_before_self_swap = matrix.clone(); let state_before_self_swap = matrix.clone();
matrix.swap_columns(1, 1); matrix.swap_columns(1, 1);
assert_eq!( assert_eq!(
@ -1879,7 +1737,7 @@ mod tests {
"Swapping a column with itself should not change the matrix" "Swapping a column with itself should not change the matrix"
); );
// 6. Test swapping adjacent columns // Test swapping adjacent columns
let mut matrix2 = create_test_matrix_i32(); let mut matrix2 = create_test_matrix_i32();
let initial_col0_data_m2 = matrix2.column(0).to_vec(); let initial_col0_data_m2 = matrix2.column(0).to_vec();
let initial_col1_data_m2 = matrix2.column(1).to_vec(); let initial_col1_data_m2 = matrix2.column(1).to_vec();