Implement broadcasting trait for matrices and enhance element-wise comparison methods

This commit is contained in:
Palash Tyagi 2025-04-24 23:50:20 +01:00
parent 21fc4a7241
commit 56bb579530

View File

@ -183,24 +183,52 @@ impl<'a, T> MatrixRow<'a, T> {
(0..self.matrix.cols).map(move |c| &self.matrix[(self.row, c)]) (0..self.matrix.cols).map(move |c| &self.matrix[(self.row, c)])
} }
} }
/// A trait to turn either a Matrix<T> or a scalar T into a Vec<T> of
/// length rows*cols (broadcasting the scalar).
pub trait Broadcastable<T> {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>;
}
/// Implements elementwise eq, lt, le, gt and ge for `Matrix<T>`. impl<T: Clone> Broadcastable<T> for T {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
vec![self.clone(); rows * cols]
}
}
impl<T: Clone> Broadcastable<T> for Matrix<T> {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
assert_eq!(self.rows, rows, "row count mismatch");
assert_eq!(self.cols, cols, "col count mismatch");
self.data.clone()
}
}
/// Generates element-wise eq, lt, le, gt and ge methods
/// where the rhs can be a Matrix<T> or a scalar T.
macro_rules! impl_elementwise_cmp { macro_rules! impl_elementwise_cmp {
( (
$( $method:ident => $op:tt ),* $(,)? $( $method:ident => $op:tt ),* $(,)?
) => { ) => {
impl<T: PartialOrd> Matrix<T> { impl<T: PartialOrd + Clone> Matrix<T> {
$( $(
#[doc = concat!("Elementwise comparison `x ", stringify!($op), " y`.")] #[doc = concat!("Element-wise comparison `self ", stringify!($op), " rhs`,\n\
pub fn $method(&self, rhs: &Matrix<T>) -> BoolMatrix { where `rhs` may be a Matrix<T> or a scalar T.")]
assert_eq!(self.rows, rhs.rows, "row count mismatch"); pub fn $method<Rhs>(&self, rhs: Rhs) -> BoolMatrix
assert_eq!(self.cols, rhs.cols, "col count mismatch"); where
Rhs: Broadcastable<T>,
{
// Prepare broadcasted rhs-data
let rhs_data = rhs.to_vec(self.rows, self.cols);
// Pairwise compare
let data = self let data = self
.data .data
.iter() .iter()
.zip(&rhs.data) .cloned()
.zip(rhs_data.into_iter())
.map(|(a, b)| a $op b) .map(|(a, b)| a $op b)
.collect(); .collect();
BoolMatrix::from_vec(data, self.rows, self.cols) BoolMatrix::from_vec(data, self.rows, self.cols)
} }
)* )*
@ -208,7 +236,7 @@ macro_rules! impl_elementwise_cmp {
}; };
} }
// Invoke it for all five operations: // Instantiate element-wise comparison implementations for matrices.
impl_elementwise_cmp! { impl_elementwise_cmp! {
eq_elementwise => ==, eq_elementwise => ==,
lt_elementwise => <, lt_elementwise => <,