Rename test matrix creation functions for clarity and consistency

This commit is contained in:
Palash Tyagi 2025-05-03 01:16:33 +01:00
parent 3cb68be062
commit 73a30d45c5

View File

@ -624,7 +624,7 @@ mod tests {
use super::{BoolMatrix, FloatMatrix, Matrix, StringMatrix};
// Helper function to create a basic Matrix for testing
fn create_test_matrix() -> Matrix<i32> {
fn static_test_matrix() -> Matrix<i32> {
// Column-major data:
// 1 4 7
// 2 5 8
@ -634,7 +634,7 @@ mod tests {
}
// Another helper for a different size
fn create_test_matrix_2x4() -> Matrix<i32> {
fn static_test_matrix_2x4() -> Matrix<i32> {
// Column-major data:
// 1 3 5 7
// 2 4 6 8
@ -733,7 +733,7 @@ mod tests {
#[test]
fn test_getters() {
let matrix = create_test_matrix();
let matrix = static_test_matrix();
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
@ -741,7 +741,7 @@ mod tests {
#[test]
fn test_index_and_get() {
let matrix = create_test_matrix();
let matrix = static_test_matrix();
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(1, 1)], 5);
assert_eq!(matrix[(2, 2)], 9);
@ -753,21 +753,21 @@ mod tests {
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_index_out_of_bounds_row() {
let matrix = create_test_matrix(); // 3x3
fn test_index_out_of_bounds_row_alt() {
let matrix = static_test_matrix(); // 3x3
let _ = matrix[(3, 0)];
}
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_index_out_of_bounds_col() {
let matrix = create_test_matrix(); // 3x3
fn test_index_out_of_bounds_col_alt() {
let matrix = static_test_matrix(); // 3x3
let _ = matrix[(0, 3)];
}
#[test]
fn test_index_mut_and_get_mut() {
let mut matrix = create_test_matrix(); // 3x3
let mut matrix = static_test_matrix(); // 3x3
matrix[(0, 0)] = 10;
matrix[(1, 1)] = 20;
@ -794,20 +794,20 @@ mod tests {
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_index_mut_out_of_bounds_row() {
let mut matrix = create_test_matrix(); // 3x3
let mut matrix = static_test_matrix(); // 3x3
matrix[(3, 0)] = 99;
}
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_index_mut_out_of_bounds_col() {
let mut matrix = create_test_matrix(); // 3x3
let mut matrix = static_test_matrix(); // 3x3
matrix[(0, 3)] = 99;
}
#[test]
fn test_column() {
let matrix = create_test_matrix_2x4(); // 2x4
let matrix = static_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
@ -818,15 +818,15 @@ mod tests {
}
#[test]
#[should_panic(expected = "range end index")]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_column_out_of_bounds() {
let matrix = create_test_matrix_2x4(); // 2x4
let matrix = static_test_matrix_2x4(); // 2x4
matrix.column(4);
}
#[test]
fn test_column_mut() {
let mut matrix = create_test_matrix_2x4(); // 2x4
let mut matrix = static_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
@ -851,15 +851,15 @@ mod tests {
}
#[test]
#[should_panic(expected = "range end index")]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_column_mut_out_of_bounds() {
let mut matrix = create_test_matrix_2x4(); // 2x4
let mut matrix = static_test_matrix_2x4(); // 2x4
matrix.column_mut(4);
}
#[test]
fn test_iter_columns() {
let matrix = create_test_matrix_2x4(); // 2x4
let matrix = static_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
@ -873,7 +873,7 @@ mod tests {
#[test]
fn test_iter_rows() {
let matrix = create_test_matrix_2x4(); // 2x4
let matrix = static_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
@ -889,7 +889,7 @@ mod tests {
// test data_mut
#[test]
fn test_data_mut() {
let mut matrix = create_test_matrix(); // 3x3
let mut matrix = static_test_matrix(); // 3x3
// 1 4 7
// 2 5 8
// 3 6 9
@ -904,7 +904,7 @@ mod tests {
#[test]
fn test_matrix_row_get_and_iter() {
let matrix = create_test_matrix_2x4(); // 2x4
let matrix = static_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
@ -924,7 +924,7 @@ mod tests {
#[test]
fn test_swap_columns() {
let mut matrix = create_test_matrix(); // 3x3
let mut matrix = static_test_matrix(); // 3x3
// 1 4 7
// 2 5 8
// 3 6 9
@ -958,15 +958,15 @@ mod tests {
}
#[test]
#[should_panic(expected = "column index out of bounds")]
#[should_panic(expected = "column index c2=3 out of bounds for 3 columns")]
fn test_swap_columns_out_of_bounds() {
let mut matrix = create_test_matrix(); // 3x3
let mut matrix = static_test_matrix(); // 3x3
matrix.swap_columns(0, 3);
}
#[test]
fn test_delete_column() {
let mut matrix = create_test_matrix_2x4(); // 2x4
let mut matrix = static_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
@ -1015,15 +1015,15 @@ mod tests {
}
#[test]
#[should_panic(expected = "column index out of bounds")]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_delete_column_out_of_bounds() {
let mut matrix = create_test_matrix_2x4(); // 2x4
let mut matrix = static_test_matrix_2x4(); // 2x4
matrix.delete_column(4);
}
#[test]
fn test_delete_row() {
let mut matrix = create_test_matrix(); // 3x3
let mut matrix = static_test_matrix(); // 3x3
// 1 4 7
// 2 5 8
// 3 6 9
@ -1065,15 +1065,15 @@ mod tests {
}
#[test]
#[should_panic(expected = "row index out of bounds")]
#[should_panic(expected = "row index 3 out of bounds for 3 rows")]
fn test_delete_row_out_of_bounds() {
let mut matrix = create_test_matrix(); // 3x3
let mut matrix = static_test_matrix(); // 3x3
matrix.delete_row(3);
}
#[test]
fn test_add_column() {
let mut matrix = create_test_matrix_2x4(); // 2x4
let mut matrix = static_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
@ -1131,9 +1131,9 @@ mod tests {
}
#[test]
#[should_panic(expected = "column index out of bounds")]
#[should_panic(expected = "add_column index 5 out of bounds for 4 columns")]
fn test_add_column_out_of_bounds() {
let mut matrix = create_test_matrix_2x4(); // 2x4
let mut matrix = static_test_matrix_2x4(); // 2x4
let new_col = vec![9, 10];
matrix.add_column(5, new_col); // Index 5 is out of bounds for 4 columns
}
@ -1141,14 +1141,14 @@ mod tests {
#[test]
#[should_panic(expected = "column length mismatch")]
fn test_add_column_length_mismatch() {
let mut matrix = create_test_matrix_2x4(); // 2x4 (2 rows)
let mut matrix = static_test_matrix_2x4(); // 2x4 (2 rows)
let new_col = vec![9, 10, 11]; // Wrong length
matrix.add_column(0, new_col);
}
#[test]
fn test_add_row() {
let mut matrix = create_test_matrix_2x4(); // 2x4
let mut matrix = static_test_matrix_2x4(); // 2x4
// 1 3 5 7
// 2 4 6 8
@ -1197,8 +1197,9 @@ mod tests {
assert_eq!(matrix.rows(), 4);
assert_eq!(matrix.cols(), 4);
assert_eq!(matrix[(0, 0)], 13);
assert_eq!(matrix[(0, 1)], 14);
assert_eq!(matrix[(0, 2)], 15);
assert_eq!(matrix[(0, 3)], 16);
// Check some existing elements to ensure they shifted correctly
assert_eq!(matrix[(1, 0)], 1);
assert_eq!(matrix[(2, 1)], 10);
assert_eq!(matrix[(3, 3)], 8);
@ -1219,9 +1220,9 @@ mod tests {
}
#[test]
#[should_panic(expected = "row index out of bounds")]
#[should_panic(expected = "add_row index 3 out of bounds for 2 rows")]
fn test_add_row_out_of_bounds() {
let mut matrix = create_test_matrix_2x4(); // 2x4
let mut matrix = static_test_matrix_2x4(); // 2x4
let new_row = vec![9, 10, 11, 12];
matrix.add_row(3, new_row); // Index 3 is out of bounds for 2 rows
}
@ -1229,14 +1230,14 @@ mod tests {
#[test]
#[should_panic(expected = "row length mismatch")]
fn test_add_row_length_mismatch() {
let mut matrix = create_test_matrix_2x4(); // 2x4 (4 cols)
let mut matrix = static_test_matrix_2x4(); // 2x4 (4 cols)
let new_row = vec![9, 10, 11]; // Wrong length
matrix.add_row(0, new_row);
}
#[test]
fn test_elementwise_add() {
let matrix1 = create_test_matrix(); // 3x3
let matrix1 = static_test_matrix(); // 3x3
let matrix2 = Matrix::from_vec(vec![9, 8, 7, 6, 5, 4, 3, 2, 1], 3, 3); // 3x3
let result = &matrix1 + &matrix2;
@ -1257,7 +1258,7 @@ mod tests {
#[test]
fn test_elementwise_sub() {
let matrix1 = create_test_matrix(); // 3x3
let matrix1 = static_test_matrix(); // 3x3
let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], 3, 3); // 3x3
let result = &matrix1 - &matrix2;
@ -1278,7 +1279,7 @@ mod tests {
#[test]
fn test_elementwise_mul() {
let matrix1 = create_test_matrix(); // 3x3
let matrix1 = static_test_matrix(); // 3x3
let matrix2 = Matrix::from_vec(vec![1, 2, 3, 1, 2, 3, 1, 2, 3], 3, 3); // 3x3
let result = &matrix1 * &matrix2;
@ -1299,7 +1300,7 @@ mod tests {
#[test]
fn test_elementwise_div() {
let matrix1 = create_test_matrix(); // 3x3
let matrix1 = static_test_matrix(); // 3x3
let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 7, 8, 9], 3, 3); // 3x3
let result = &matrix1 / &matrix2; // Integer division
@ -1319,18 +1320,18 @@ mod tests {
}
#[test]
#[should_panic(expected = "row count mismatch")]
#[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 rows")]
fn test_elementwise_op_row_mismatch() {
let matrix1 = create_test_matrix(); // 3x3
let matrix2 = create_test_matrix_2x4(); // 2x4
let matrix1 = static_test_matrix(); // 3x3
let matrix2 = static_test_matrix_2x4(); // 2x4
let _ = &matrix1 + &matrix2; // Should panic
}
#[test]
#[should_panic(expected = "row count mismatch")]
#[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 ro")]
fn test_elementwise_op_col_mismatch() {
let matrix1 = create_test_matrix(); // 3x3
let matrix2 = create_test_matrix_2x4(); // 2x4
let matrix1 = static_test_matrix(); // 3x3
let matrix2 = static_test_matrix_2x4(); // 2x4
let _ = &matrix1 * &matrix2; // Should panic
}
@ -1421,7 +1422,7 @@ mod tests {
}
#[test]
#[should_panic(expected = "col count mismatch")]
#[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")]
fn test_bitwise_op_row_mismatch() {
let data1 = vec![true, false, true, false]; // 2x2
let data2 = vec![true, true, false, false, true, true]; // 2x3
@ -1431,7 +1432,7 @@ mod tests {
}
#[test]
#[should_panic(expected = "col count mismatch")]
#[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")]
fn test_bitwise_op_col_mismatch() {
let data1 = vec![true, false, true, false]; // 2x2
let data2 = vec![true, true, false, false, true, true]; // 2x3
@ -1466,7 +1467,7 @@ mod tests {
matrix.add_column(1, new_col); // Add at index 1
// Should be:
// hello c d
// b e f
// b e f
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], "hello".to_string());
@ -1480,9 +1481,9 @@ mod tests {
let new_row = vec!["g".to_string(), "h".to_string(), "i".to_string()];
matrix.add_row(0, new_row); // Add at index 0
// Should be:
// g h i
// g h i
// hello e c
// b f d
// b f d
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], "g".to_string());
@ -1596,25 +1597,25 @@ mod tests {
// Test broadcastable operations
#[test]
fn test_comparision_broadcast() {
let matrix = create_test_matrix();
let matrix = static_test_matrix();
// test all > 0
let result = matrix.gt_elementwise(0).as_vec();
let result = matrix.gt_elem(0).into_vec();
let expected = vec![true; result.len()];
assert_eq!(result, expected);
let ma = create_test_matrix();
let mb = create_test_matrix();
let ma = static_test_matrix();
let mb = static_test_matrix();
let result = ma.eq_elementwise(mb);
let result = ma.eq_elem(mb);
assert!(result.all());
let result = matrix.lt_elementwise(1e10 as i32).all();
let result = matrix.lt_elem(1e10 as i32).all();
assert!(result);
for i in 0..matrix.rows() {
for j in 0..matrix.cols() {
let vx = matrix[(i, j)];
let c = &(matrix.le_elementwise(vx)) & &(matrix.ge_elementwise(vx));
let c = &(matrix.le_elem(vx)) & &(matrix.ge_elem(vx));
assert_eq!(c.count(), 1);
}
}
@ -1622,7 +1623,7 @@ mod tests {
#[test]
fn test_arithmetic_broadcast() {
let matrix = create_test_matrix();
let matrix = static_test_matrix();
let result = &matrix + 1;
for i in 0..matrix.rows() {
for j in 0..matrix.cols() {