Implement broadcasting trait for matrices and enhance element-wise comparison methods

This commit is contained in:
Palash Tyagi 2025-04-24 23:50:20 +01:00
parent 21fc4a7241
commit 56bb579530

View File

@ -183,24 +183,52 @@ impl<'a, T> MatrixRow<'a, T> {
(0..self.matrix.cols).map(move |c| &self.matrix[(self.row, c)])
}
}
/// A trait to turn either a Matrix<T> or a scalar T into a Vec<T> of
/// length rows*cols (broadcasting the scalar).
pub trait Broadcastable<T> {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>;
}
/// Implements elementwise eq, lt, le, gt and ge for `Matrix<T>`.
impl<T: Clone> Broadcastable<T> for T {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
vec![self.clone(); rows * cols]
}
}
impl<T: Clone> Broadcastable<T> for Matrix<T> {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
assert_eq!(self.rows, rows, "row count mismatch");
assert_eq!(self.cols, cols, "col count mismatch");
self.data.clone()
}
}
/// Generates element-wise eq, lt, le, gt and ge methods
/// where the rhs can be a Matrix<T> or a scalar T.
macro_rules! impl_elementwise_cmp {
(
$( $method:ident => $op:tt ),* $(,)?
) => {
impl<T: PartialOrd> Matrix<T> {
impl<T: PartialOrd + Clone> Matrix<T> {
$(
#[doc = concat!("Elementwise comparison `x ", stringify!($op), " y`.")]
pub fn $method(&self, rhs: &Matrix<T>) -> BoolMatrix {
assert_eq!(self.rows, rhs.rows, "row count mismatch");
assert_eq!(self.cols, rhs.cols, "col count mismatch");
#[doc = concat!("Element-wise comparison `self ", stringify!($op), " rhs`,\n\
where `rhs` may be a Matrix<T> or a scalar T.")]
pub fn $method<Rhs>(&self, rhs: Rhs) -> BoolMatrix
where
Rhs: Broadcastable<T>,
{
// Prepare broadcasted rhs-data
let rhs_data = rhs.to_vec(self.rows, self.cols);
// Pairwise compare
let data = self
.data
.iter()
.zip(&rhs.data)
.cloned()
.zip(rhs_data.into_iter())
.map(|(a, b)| a $op b)
.collect();
BoolMatrix::from_vec(data, self.rows, self.cols)
}
)*
@ -208,7 +236,7 @@ macro_rules! impl_elementwise_cmp {
};
}
// Invoke it for all five operations:
// Instantiate element-wise comparison implementations for matrices.
impl_elementwise_cmp! {
eq_elementwise => ==,
lt_elementwise => <,