Refactor element-wise comparison implementation for matrices

This commit is contained in:
Palash Tyagi 2025-04-24 23:39:34 +01:00
parent 601c1c58d0
commit 21fc4a7241

View File

@ -184,27 +184,37 @@ impl<'a, T> MatrixRow<'a, T> {
} }
} }
macro_rules! impl_cmp { /// Implements elementwise eq, lt, le, gt and ge for `Matrix<T>`.
($name:ident, $op:tt) => { macro_rules! impl_elementwise_cmp {
pub fn $name(&self, rhs: &Matrix<T>) -> BoolMatrix { (
assert_eq!(self.rows, rhs.rows, ""); $( $method:ident => $op:tt ),* $(,)?
assert_eq!(self.cols, rhs.cols, ""); ) => {
impl<T: PartialOrd> Matrix<T> {
let data = self.data.iter() $(
.zip(rhs.data.iter()) #[doc = concat!("Elementwise comparison `x ", stringify!($op), " y`.")]
.map(|(a, b)| a $op b) pub fn $method(&self, rhs: &Matrix<T>) -> BoolMatrix {
.collect(); assert_eq!(self.rows, rhs.rows, "row count mismatch");
assert_eq!(self.cols, rhs.cols, "col count mismatch");
BoolMatrix::from_vec(data, self.rows, self.cols) let data = self
.data
.iter()
.zip(&rhs.data)
.map(|(a, b)| a $op b)
.collect();
BoolMatrix::from_vec(data, self.rows, self.cols)
}
)*
} }
}; };
} }
impl<T: PartialOrd> Matrix<T> { // Invoke it for all five operations:
impl_cmp!(lt_elementwise, <); impl_elementwise_cmp! {
impl_cmp!(le_elementwise, <=); eq_elementwise => ==,
impl_cmp!(gt_elementwise, >); lt_elementwise => <,
impl_cmp!(ge_elementwise, >=); le_elementwise => <=,
gt_elementwise => >,
ge_elementwise => >=,
} }
/// Generates element-wise arithmetic implementations for matrices. /// Generates element-wise arithmetic implementations for matrices.