diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index d1be2d5..ec61e66 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -624,7 +624,7 @@ mod tests { use super::{BoolMatrix, FloatMatrix, Matrix, StringMatrix}; // Helper function to create a basic Matrix for testing - fn create_test_matrix() -> Matrix { + fn static_test_matrix() -> Matrix { // Column-major data: // 1 4 7 // 2 5 8 @@ -634,7 +634,7 @@ mod tests { } // Another helper for a different size - fn create_test_matrix_2x4() -> Matrix { + fn static_test_matrix_2x4() -> Matrix { // Column-major data: // 1 3 5 7 // 2 4 6 8 @@ -733,7 +733,7 @@ mod tests { #[test] fn test_getters() { - let matrix = create_test_matrix(); + let matrix = static_test_matrix(); assert_eq!(matrix.rows(), 3); assert_eq!(matrix.cols(), 3); assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]); @@ -741,7 +741,7 @@ mod tests { #[test] fn test_index_and_get() { - let matrix = create_test_matrix(); + let matrix = static_test_matrix(); assert_eq!(matrix[(0, 0)], 1); assert_eq!(matrix[(1, 1)], 5); assert_eq!(matrix[(2, 2)], 9); @@ -753,21 +753,21 @@ mod tests { #[test] #[should_panic(expected = "index out of bounds")] - fn test_index_out_of_bounds_row() { - let matrix = create_test_matrix(); // 3x3 + fn test_index_out_of_bounds_row_alt() { + let matrix = static_test_matrix(); // 3x3 let _ = matrix[(3, 0)]; } #[test] #[should_panic(expected = "index out of bounds")] - fn test_index_out_of_bounds_col() { - let matrix = create_test_matrix(); // 3x3 + fn test_index_out_of_bounds_col_alt() { + let matrix = static_test_matrix(); // 3x3 let _ = matrix[(0, 3)]; } #[test] fn test_index_mut_and_get_mut() { - let mut matrix = create_test_matrix(); // 3x3 + let mut matrix = static_test_matrix(); // 3x3 matrix[(0, 0)] = 10; matrix[(1, 1)] = 20; @@ -794,20 +794,20 @@ mod tests { #[test] #[should_panic(expected = "index out of bounds")] fn test_index_mut_out_of_bounds_row() { - let mut matrix = create_test_matrix(); // 3x3 + let mut matrix = static_test_matrix(); // 3x3 matrix[(3, 0)] = 99; } #[test] #[should_panic(expected = "index out of bounds")] fn test_index_mut_out_of_bounds_col() { - let mut matrix = create_test_matrix(); // 3x3 + let mut matrix = static_test_matrix(); // 3x3 matrix[(0, 3)] = 99; } #[test] fn test_column() { - let matrix = create_test_matrix_2x4(); // 2x4 + let matrix = static_test_matrix_2x4(); // 2x4 // 1 3 5 7 // 2 4 6 8 @@ -818,15 +818,15 @@ mod tests { } #[test] - #[should_panic(expected = "range end index")] + #[should_panic(expected = "column index 4 out of bounds for 4 columns")] fn test_column_out_of_bounds() { - let matrix = create_test_matrix_2x4(); // 2x4 + let matrix = static_test_matrix_2x4(); // 2x4 matrix.column(4); } #[test] fn test_column_mut() { - let mut matrix = create_test_matrix_2x4(); // 2x4 + let mut matrix = static_test_matrix_2x4(); // 2x4 // 1 3 5 7 // 2 4 6 8 @@ -851,15 +851,15 @@ mod tests { } #[test] - #[should_panic(expected = "range end index")] + #[should_panic(expected = "column index 4 out of bounds for 4 columns")] fn test_column_mut_out_of_bounds() { - let mut matrix = create_test_matrix_2x4(); // 2x4 + let mut matrix = static_test_matrix_2x4(); // 2x4 matrix.column_mut(4); } #[test] fn test_iter_columns() { - let matrix = create_test_matrix_2x4(); // 2x4 + let matrix = static_test_matrix_2x4(); // 2x4 // 1 3 5 7 // 2 4 6 8 @@ -873,7 +873,7 @@ mod tests { #[test] fn test_iter_rows() { - let matrix = create_test_matrix_2x4(); // 2x4 + let matrix = static_test_matrix_2x4(); // 2x4 // 1 3 5 7 // 2 4 6 8 @@ -889,7 +889,7 @@ mod tests { // test data_mut #[test] fn test_data_mut() { - let mut matrix = create_test_matrix(); // 3x3 + let mut matrix = static_test_matrix(); // 3x3 // 1 4 7 // 2 5 8 // 3 6 9 @@ -904,7 +904,7 @@ mod tests { #[test] fn test_matrix_row_get_and_iter() { - let matrix = create_test_matrix_2x4(); // 2x4 + let matrix = static_test_matrix_2x4(); // 2x4 // 1 3 5 7 // 2 4 6 8 @@ -924,7 +924,7 @@ mod tests { #[test] fn test_swap_columns() { - let mut matrix = create_test_matrix(); // 3x3 + let mut matrix = static_test_matrix(); // 3x3 // 1 4 7 // 2 5 8 // 3 6 9 @@ -958,15 +958,15 @@ mod tests { } #[test] - #[should_panic(expected = "column index out of bounds")] + #[should_panic(expected = "column index c2=3 out of bounds for 3 columns")] fn test_swap_columns_out_of_bounds() { - let mut matrix = create_test_matrix(); // 3x3 + let mut matrix = static_test_matrix(); // 3x3 matrix.swap_columns(0, 3); } #[test] fn test_delete_column() { - let mut matrix = create_test_matrix_2x4(); // 2x4 + let mut matrix = static_test_matrix_2x4(); // 2x4 // 1 3 5 7 // 2 4 6 8 @@ -1015,15 +1015,15 @@ mod tests { } #[test] - #[should_panic(expected = "column index out of bounds")] + #[should_panic(expected = "column index 4 out of bounds for 4 columns")] fn test_delete_column_out_of_bounds() { - let mut matrix = create_test_matrix_2x4(); // 2x4 + let mut matrix = static_test_matrix_2x4(); // 2x4 matrix.delete_column(4); } #[test] fn test_delete_row() { - let mut matrix = create_test_matrix(); // 3x3 + let mut matrix = static_test_matrix(); // 3x3 // 1 4 7 // 2 5 8 // 3 6 9 @@ -1065,15 +1065,15 @@ mod tests { } #[test] - #[should_panic(expected = "row index out of bounds")] + #[should_panic(expected = "row index 3 out of bounds for 3 rows")] fn test_delete_row_out_of_bounds() { - let mut matrix = create_test_matrix(); // 3x3 + let mut matrix = static_test_matrix(); // 3x3 matrix.delete_row(3); } #[test] fn test_add_column() { - let mut matrix = create_test_matrix_2x4(); // 2x4 + let mut matrix = static_test_matrix_2x4(); // 2x4 // 1 3 5 7 // 2 4 6 8 @@ -1131,9 +1131,9 @@ mod tests { } #[test] - #[should_panic(expected = "column index out of bounds")] + #[should_panic(expected = "add_column index 5 out of bounds for 4 columns")] fn test_add_column_out_of_bounds() { - let mut matrix = create_test_matrix_2x4(); // 2x4 + let mut matrix = static_test_matrix_2x4(); // 2x4 let new_col = vec![9, 10]; matrix.add_column(5, new_col); // Index 5 is out of bounds for 4 columns } @@ -1141,14 +1141,14 @@ mod tests { #[test] #[should_panic(expected = "column length mismatch")] fn test_add_column_length_mismatch() { - let mut matrix = create_test_matrix_2x4(); // 2x4 (2 rows) + let mut matrix = static_test_matrix_2x4(); // 2x4 (2 rows) let new_col = vec![9, 10, 11]; // Wrong length matrix.add_column(0, new_col); } #[test] fn test_add_row() { - let mut matrix = create_test_matrix_2x4(); // 2x4 + let mut matrix = static_test_matrix_2x4(); // 2x4 // 1 3 5 7 // 2 4 6 8 @@ -1197,8 +1197,9 @@ mod tests { assert_eq!(matrix.rows(), 4); assert_eq!(matrix.cols(), 4); assert_eq!(matrix[(0, 0)], 13); + assert_eq!(matrix[(0, 1)], 14); + assert_eq!(matrix[(0, 2)], 15); assert_eq!(matrix[(0, 3)], 16); - // Check some existing elements to ensure they shifted correctly assert_eq!(matrix[(1, 0)], 1); assert_eq!(matrix[(2, 1)], 10); assert_eq!(matrix[(3, 3)], 8); @@ -1219,9 +1220,9 @@ mod tests { } #[test] - #[should_panic(expected = "row index out of bounds")] + #[should_panic(expected = "add_row index 3 out of bounds for 2 rows")] fn test_add_row_out_of_bounds() { - let mut matrix = create_test_matrix_2x4(); // 2x4 + let mut matrix = static_test_matrix_2x4(); // 2x4 let new_row = vec![9, 10, 11, 12]; matrix.add_row(3, new_row); // Index 3 is out of bounds for 2 rows } @@ -1229,14 +1230,14 @@ mod tests { #[test] #[should_panic(expected = "row length mismatch")] fn test_add_row_length_mismatch() { - let mut matrix = create_test_matrix_2x4(); // 2x4 (4 cols) + let mut matrix = static_test_matrix_2x4(); // 2x4 (4 cols) let new_row = vec![9, 10, 11]; // Wrong length matrix.add_row(0, new_row); } #[test] fn test_elementwise_add() { - let matrix1 = create_test_matrix(); // 3x3 + let matrix1 = static_test_matrix(); // 3x3 let matrix2 = Matrix::from_vec(vec![9, 8, 7, 6, 5, 4, 3, 2, 1], 3, 3); // 3x3 let result = &matrix1 + &matrix2; @@ -1257,7 +1258,7 @@ mod tests { #[test] fn test_elementwise_sub() { - let matrix1 = create_test_matrix(); // 3x3 + let matrix1 = static_test_matrix(); // 3x3 let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], 3, 3); // 3x3 let result = &matrix1 - &matrix2; @@ -1278,7 +1279,7 @@ mod tests { #[test] fn test_elementwise_mul() { - let matrix1 = create_test_matrix(); // 3x3 + let matrix1 = static_test_matrix(); // 3x3 let matrix2 = Matrix::from_vec(vec![1, 2, 3, 1, 2, 3, 1, 2, 3], 3, 3); // 3x3 let result = &matrix1 * &matrix2; @@ -1299,7 +1300,7 @@ mod tests { #[test] fn test_elementwise_div() { - let matrix1 = create_test_matrix(); // 3x3 + let matrix1 = static_test_matrix(); // 3x3 let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 7, 8, 9], 3, 3); // 3x3 let result = &matrix1 / &matrix2; // Integer division @@ -1319,18 +1320,18 @@ mod tests { } #[test] - #[should_panic(expected = "row count mismatch")] + #[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 rows")] fn test_elementwise_op_row_mismatch() { - let matrix1 = create_test_matrix(); // 3x3 - let matrix2 = create_test_matrix_2x4(); // 2x4 + let matrix1 = static_test_matrix(); // 3x3 + let matrix2 = static_test_matrix_2x4(); // 2x4 let _ = &matrix1 + &matrix2; // Should panic } #[test] - #[should_panic(expected = "row count mismatch")] + #[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 ro")] fn test_elementwise_op_col_mismatch() { - let matrix1 = create_test_matrix(); // 3x3 - let matrix2 = create_test_matrix_2x4(); // 2x4 + let matrix1 = static_test_matrix(); // 3x3 + let matrix2 = static_test_matrix_2x4(); // 2x4 let _ = &matrix1 * &matrix2; // Should panic } @@ -1421,7 +1422,7 @@ mod tests { } #[test] - #[should_panic(expected = "col count mismatch")] + #[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")] fn test_bitwise_op_row_mismatch() { let data1 = vec![true, false, true, false]; // 2x2 let data2 = vec![true, true, false, false, true, true]; // 2x3 @@ -1431,7 +1432,7 @@ mod tests { } #[test] - #[should_panic(expected = "col count mismatch")] + #[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")] fn test_bitwise_op_col_mismatch() { let data1 = vec![true, false, true, false]; // 2x2 let data2 = vec![true, true, false, false, true, true]; // 2x3 @@ -1466,7 +1467,7 @@ mod tests { matrix.add_column(1, new_col); // Add at index 1 // Should be: // hello c d - // b e f + // b e f assert_eq!(matrix.rows(), 2); assert_eq!(matrix.cols(), 3); assert_eq!(matrix[(0, 0)], "hello".to_string()); @@ -1480,9 +1481,9 @@ mod tests { let new_row = vec!["g".to_string(), "h".to_string(), "i".to_string()]; matrix.add_row(0, new_row); // Add at index 0 // Should be: - // g h i + // g h i // hello e c - // b f d + // b f d assert_eq!(matrix.rows(), 3); assert_eq!(matrix.cols(), 3); assert_eq!(matrix[(0, 0)], "g".to_string()); @@ -1596,25 +1597,25 @@ mod tests { // Test broadcastable operations #[test] fn test_comparision_broadcast() { - let matrix = create_test_matrix(); + let matrix = static_test_matrix(); // test all > 0 - let result = matrix.gt_elementwise(0).as_vec(); + let result = matrix.gt_elem(0).into_vec(); let expected = vec![true; result.len()]; assert_eq!(result, expected); - let ma = create_test_matrix(); - let mb = create_test_matrix(); + let ma = static_test_matrix(); + let mb = static_test_matrix(); - let result = ma.eq_elementwise(mb); + let result = ma.eq_elem(mb); assert!(result.all()); - let result = matrix.lt_elementwise(1e10 as i32).all(); + let result = matrix.lt_elem(1e10 as i32).all(); assert!(result); for i in 0..matrix.rows() { for j in 0..matrix.cols() { let vx = matrix[(i, j)]; - let c = &(matrix.le_elementwise(vx)) & &(matrix.ge_elementwise(vx)); + let c = &(matrix.le_elem(vx)) & &(matrix.ge_elem(vx)); assert_eq!(c.count(), 1); } } @@ -1622,7 +1623,7 @@ mod tests { #[test] fn test_arithmetic_broadcast() { - let matrix = create_test_matrix(); + let matrix = static_test_matrix(); let result = &matrix + 1; for i in 0..matrix.rows() { for j in 0..matrix.cols() {