diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index 5670267..c04873c 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -184,6 +184,32 @@ impl<'a, T> MatrixRow<'a, T> { } } +// PartialEq for element-wise comparison +impl Matrix { + /// Performs element-wise equality comparison with another matrix. + /// Returns a new `Matrix` where each element is the result of comparing the corresponding elements. + /// Panics if the matrices have different dimensions. + pub fn eq_elementwise(&self, rhs: &Matrix) -> BoolMatrix { + assert_eq!( + self.rows, rhs.rows, + "Matrices must have the same number of rows for element-wise comparison" + ); + assert_eq!( + self.cols, rhs.cols, + "Matrices must have the same number of columns for element-wise comparison" + ); + + let data = self + .data + .iter() + .zip(rhs.data.iter()) + .map(|(a, b)| a == b) // Use T::PartialEq::eq + .collect(); + + BoolMatrix::from_vec(data, self.rows, self.cols) + } +} + /// Generates element-wise arithmetic implementations for matrices. macro_rules! impl_elementwise_op { ($OpTrait:ident, $method:ident, $op:tt) => { @@ -1161,8 +1187,7 @@ mod tests { assert!((div[(0, 1)] - 3.0 / 2.5).abs() < 1e-9); // 1.2 assert!((div[(1, 1)] - 4.0 / 3.5).abs() < 1e-9); // 1.14... } - - + fn create_test_matrix_i32() -> Matrix { Matrix::from_cols(vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]) }