Implement element-wise equality comparison for matrices

This commit is contained in:
Palash Tyagi 2025-04-24 23:19:02 +01:00
parent 415bb4a61e
commit ea8e1c8471

View File

@ -184,6 +184,32 @@ impl<'a, T> MatrixRow<'a, T> {
}
}
// PartialEq for element-wise comparison
impl<T: PartialEq> Matrix<T> {
/// Performs element-wise equality comparison with another matrix.
/// Returns a new `Matrix<bool>` where each element is the result of comparing the corresponding elements.
/// Panics if the matrices have different dimensions.
pub fn eq_elementwise(&self, rhs: &Matrix<T>) -> BoolMatrix {
assert_eq!(
self.rows, rhs.rows,
"Matrices must have the same number of rows for element-wise comparison"
);
assert_eq!(
self.cols, rhs.cols,
"Matrices must have the same number of columns for element-wise comparison"
);
let data = self
.data
.iter()
.zip(rhs.data.iter())
.map(|(a, b)| a == b) // Use T::PartialEq::eq
.collect();
BoolMatrix::from_vec(data, self.rows, self.cols)
}
}
/// Generates element-wise arithmetic implementations for matrices.
macro_rules! impl_elementwise_op {
($OpTrait:ident, $method:ident, $op:tt) => {
@ -1162,7 +1188,6 @@ mod tests {
assert!((div[(1, 1)] - 4.0 / 3.5).abs() < 1e-9); // 1.14...
}
fn create_test_matrix_i32() -> Matrix<i32> {
Matrix::from_cols(vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]])
}