mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:19:59 +00:00
Implement broadcasting trait for matrices and enhance element-wise comparison methods
This commit is contained in:
parent
21fc4a7241
commit
56bb579530
@ -183,24 +183,52 @@ impl<'a, T> MatrixRow<'a, T> {
|
||||
(0..self.matrix.cols).map(move |c| &self.matrix[(self.row, c)])
|
||||
}
|
||||
}
|
||||
/// A trait to turn either a Matrix<T> or a scalar T into a Vec<T> of
|
||||
/// length rows*cols (broadcasting the scalar).
|
||||
pub trait Broadcastable<T> {
|
||||
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>;
|
||||
}
|
||||
|
||||
/// Implements element‐wise eq, lt, le, gt and ge for `Matrix<T>`.
|
||||
impl<T: Clone> Broadcastable<T> for T {
|
||||
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
|
||||
vec![self.clone(); rows * cols]
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone> Broadcastable<T> for Matrix<T> {
|
||||
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
|
||||
assert_eq!(self.rows, rows, "row count mismatch");
|
||||
assert_eq!(self.cols, cols, "col count mismatch");
|
||||
self.data.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates element-wise eq, lt, le, gt and ge methods
|
||||
/// where the rhs can be a Matrix<T> or a scalar T.
|
||||
macro_rules! impl_elementwise_cmp {
|
||||
(
|
||||
$( $method:ident => $op:tt ),* $(,)?
|
||||
) => {
|
||||
impl<T: PartialOrd> Matrix<T> {
|
||||
impl<T: PartialOrd + Clone> Matrix<T> {
|
||||
$(
|
||||
#[doc = concat!("Element‐wise comparison `x ", stringify!($op), " y`.")]
|
||||
pub fn $method(&self, rhs: &Matrix<T>) -> BoolMatrix {
|
||||
assert_eq!(self.rows, rhs.rows, "row count mismatch");
|
||||
assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||
#[doc = concat!("Element-wise comparison `self ", stringify!($op), " rhs`,\n\
|
||||
where `rhs` may be a Matrix<T> or a scalar T.")]
|
||||
pub fn $method<Rhs>(&self, rhs: Rhs) -> BoolMatrix
|
||||
where
|
||||
Rhs: Broadcastable<T>,
|
||||
{
|
||||
// Prepare broadcasted rhs-data
|
||||
let rhs_data = rhs.to_vec(self.rows, self.cols);
|
||||
|
||||
// Pairwise compare
|
||||
let data = self
|
||||
.data
|
||||
.iter()
|
||||
.zip(&rhs.data)
|
||||
.cloned()
|
||||
.zip(rhs_data.into_iter())
|
||||
.map(|(a, b)| a $op b)
|
||||
.collect();
|
||||
|
||||
BoolMatrix::from_vec(data, self.rows, self.cols)
|
||||
}
|
||||
)*
|
||||
@ -208,7 +236,7 @@ macro_rules! impl_elementwise_cmp {
|
||||
};
|
||||
}
|
||||
|
||||
// Invoke it for all five operations:
|
||||
// Instantiate element-wise comparison implementations for matrices.
|
||||
impl_elementwise_cmp! {
|
||||
eq_elementwise => ==,
|
||||
lt_elementwise => <,
|
||||
|
Loading…
x
Reference in New Issue
Block a user