diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index 6006340..400facc 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -183,24 +183,52 @@ impl<'a, T> MatrixRow<'a, T> { (0..self.matrix.cols).map(move |c| &self.matrix[(self.row, c)]) } } +/// A trait to turn either a Matrix or a scalar T into a Vec of +/// length rows*cols (broadcasting the scalar). +pub trait Broadcastable { + fn to_vec(&self, rows: usize, cols: usize) -> Vec; +} -/// Implements element‐wise eq, lt, le, gt and ge for `Matrix`. +impl Broadcastable for T { + fn to_vec(&self, rows: usize, cols: usize) -> Vec { + vec![self.clone(); rows * cols] + } +} + +impl Broadcastable for Matrix { + fn to_vec(&self, rows: usize, cols: usize) -> Vec { + assert_eq!(self.rows, rows, "row count mismatch"); + assert_eq!(self.cols, cols, "col count mismatch"); + self.data.clone() + } +} + +/// Generates element-wise eq, lt, le, gt and ge methods +/// where the rhs can be a Matrix or a scalar T. macro_rules! impl_elementwise_cmp { ( $( $method:ident => $op:tt ),* $(,)? ) => { - impl Matrix { + impl Matrix { $( - #[doc = concat!("Element‐wise comparison `x ", stringify!($op), " y`.")] - pub fn $method(&self, rhs: &Matrix) -> BoolMatrix { - assert_eq!(self.rows, rhs.rows, "row count mismatch"); - assert_eq!(self.cols, rhs.cols, "col count mismatch"); + #[doc = concat!("Element-wise comparison `self ", stringify!($op), " rhs`,\n\ + where `rhs` may be a Matrix or a scalar T.")] + pub fn $method(&self, rhs: Rhs) -> BoolMatrix + where + Rhs: Broadcastable, + { + // Prepare broadcasted rhs-data + let rhs_data = rhs.to_vec(self.rows, self.cols); + + // Pairwise compare let data = self .data .iter() - .zip(&rhs.data) + .cloned() + .zip(rhs_data.into_iter()) .map(|(a, b)| a $op b) .collect(); + BoolMatrix::from_vec(data, self.rows, self.cols) } )* @@ -208,7 +236,7 @@ macro_rules! impl_elementwise_cmp { }; } -// Invoke it for all five operations: +// Instantiate element-wise comparison implementations for matrices. impl_elementwise_cmp! { eq_elementwise => ==, lt_elementwise => <,