Refactor element-wise comparison implementation for matrices

This commit is contained in:
Palash Tyagi 2025-04-24 23:39:34 +01:00
parent 601c1c58d0
commit 21fc4a7241

View File

@ -184,27 +184,37 @@ impl<'a, T> MatrixRow<'a, T> {
}
}
macro_rules! impl_cmp {
($name:ident, $op:tt) => {
pub fn $name(&self, rhs: &Matrix<T>) -> BoolMatrix {
assert_eq!(self.rows, rhs.rows, "");
assert_eq!(self.cols, rhs.cols, "");
let data = self.data.iter()
.zip(rhs.data.iter())
/// Implements elementwise eq, lt, le, gt and ge for `Matrix<T>`.
macro_rules! impl_elementwise_cmp {
(
$( $method:ident => $op:tt ),* $(,)?
) => {
impl<T: PartialOrd> Matrix<T> {
$(
#[doc = concat!("Elementwise comparison `x ", stringify!($op), " y`.")]
pub fn $method(&self, rhs: &Matrix<T>) -> BoolMatrix {
assert_eq!(self.rows, rhs.rows, "row count mismatch");
assert_eq!(self.cols, rhs.cols, "col count mismatch");
let data = self
.data
.iter()
.zip(&rhs.data)
.map(|(a, b)| a $op b)
.collect();
BoolMatrix::from_vec(data, self.rows, self.cols)
}
)*
}
};
}
impl<T: PartialOrd> Matrix<T> {
impl_cmp!(lt_elementwise, <);
impl_cmp!(le_elementwise, <=);
impl_cmp!(gt_elementwise, >);
impl_cmp!(ge_elementwise, >=);
// Invoke it for all five operations:
impl_elementwise_cmp! {
eq_elementwise => ==,
lt_elementwise => <,
le_elementwise => <=,
gt_elementwise => >,
ge_elementwise => >=,
}
/// Generates element-wise arithmetic implementations for matrices.