Refactor element-wise comparison implementation for matrices

This commit is contained in:
Palash Tyagi 2025-04-24 23:33:04 +01:00
parent ea8e1c8471
commit 601c1c58d0

View File

@ -184,30 +184,27 @@ impl<'a, T> MatrixRow<'a, T> {
}
}
// PartialEq for element-wise comparison
impl<T: PartialEq> Matrix<T> {
/// Performs element-wise equality comparison with another matrix.
/// Returns a new `Matrix<bool>` where each element is the result of comparing the corresponding elements.
/// Panics if the matrices have different dimensions.
pub fn eq_elementwise(&self, rhs: &Matrix<T>) -> BoolMatrix {
assert_eq!(
self.rows, rhs.rows,
"Matrices must have the same number of rows for element-wise comparison"
);
assert_eq!(
self.cols, rhs.cols,
"Matrices must have the same number of columns for element-wise comparison"
);
macro_rules! impl_cmp {
($name:ident, $op:tt) => {
pub fn $name(&self, rhs: &Matrix<T>) -> BoolMatrix {
assert_eq!(self.rows, rhs.rows, "");
assert_eq!(self.cols, rhs.cols, "");
let data = self
.data
.iter()
.zip(rhs.data.iter())
.map(|(a, b)| a == b) // Use T::PartialEq::eq
.collect();
let data = self.data.iter()
.zip(rhs.data.iter())
.map(|(a, b)| a $op b)
.collect();
BoolMatrix::from_vec(data, self.rows, self.cols)
}
BoolMatrix::from_vec(data, self.rows, self.cols)
}
};
}
impl<T: PartialOrd> Matrix<T> {
impl_cmp!(lt_elementwise, <);
impl_cmp!(le_elementwise, <=);
impl_cmp!(gt_elementwise, >);
impl_cmp!(ge_elementwise, >=);
}
/// Generates element-wise arithmetic implementations for matrices.