mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:19:59 +00:00
Implement broadcasting trait for matrices and enhance element-wise comparison methods
This commit is contained in:
parent
21fc4a7241
commit
56bb579530
@ -183,24 +183,52 @@ impl<'a, T> MatrixRow<'a, T> {
|
|||||||
(0..self.matrix.cols).map(move |c| &self.matrix[(self.row, c)])
|
(0..self.matrix.cols).map(move |c| &self.matrix[(self.row, c)])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/// A trait to turn either a Matrix<T> or a scalar T into a Vec<T> of
|
||||||
|
/// length rows*cols (broadcasting the scalar).
|
||||||
|
pub trait Broadcastable<T> {
|
||||||
|
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>;
|
||||||
|
}
|
||||||
|
|
||||||
/// Implements element‐wise eq, lt, le, gt and ge for `Matrix<T>`.
|
impl<T: Clone> Broadcastable<T> for T {
|
||||||
|
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
|
||||||
|
vec![self.clone(); rows * cols]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Clone> Broadcastable<T> for Matrix<T> {
|
||||||
|
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
|
||||||
|
assert_eq!(self.rows, rows, "row count mismatch");
|
||||||
|
assert_eq!(self.cols, cols, "col count mismatch");
|
||||||
|
self.data.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates element-wise eq, lt, le, gt and ge methods
|
||||||
|
/// where the rhs can be a Matrix<T> or a scalar T.
|
||||||
macro_rules! impl_elementwise_cmp {
|
macro_rules! impl_elementwise_cmp {
|
||||||
(
|
(
|
||||||
$( $method:ident => $op:tt ),* $(,)?
|
$( $method:ident => $op:tt ),* $(,)?
|
||||||
) => {
|
) => {
|
||||||
impl<T: PartialOrd> Matrix<T> {
|
impl<T: PartialOrd + Clone> Matrix<T> {
|
||||||
$(
|
$(
|
||||||
#[doc = concat!("Element‐wise comparison `x ", stringify!($op), " y`.")]
|
#[doc = concat!("Element-wise comparison `self ", stringify!($op), " rhs`,\n\
|
||||||
pub fn $method(&self, rhs: &Matrix<T>) -> BoolMatrix {
|
where `rhs` may be a Matrix<T> or a scalar T.")]
|
||||||
assert_eq!(self.rows, rhs.rows, "row count mismatch");
|
pub fn $method<Rhs>(&self, rhs: Rhs) -> BoolMatrix
|
||||||
assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
where
|
||||||
|
Rhs: Broadcastable<T>,
|
||||||
|
{
|
||||||
|
// Prepare broadcasted rhs-data
|
||||||
|
let rhs_data = rhs.to_vec(self.rows, self.cols);
|
||||||
|
|
||||||
|
// Pairwise compare
|
||||||
let data = self
|
let data = self
|
||||||
.data
|
.data
|
||||||
.iter()
|
.iter()
|
||||||
.zip(&rhs.data)
|
.cloned()
|
||||||
|
.zip(rhs_data.into_iter())
|
||||||
.map(|(a, b)| a $op b)
|
.map(|(a, b)| a $op b)
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
BoolMatrix::from_vec(data, self.rows, self.cols)
|
BoolMatrix::from_vec(data, self.rows, self.cols)
|
||||||
}
|
}
|
||||||
)*
|
)*
|
||||||
@ -208,7 +236,7 @@ macro_rules! impl_elementwise_cmp {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invoke it for all five operations:
|
// Instantiate element-wise comparison implementations for matrices.
|
||||||
impl_elementwise_cmp! {
|
impl_elementwise_cmp! {
|
||||||
eq_elementwise => ==,
|
eq_elementwise => ==,
|
||||||
lt_elementwise => <,
|
lt_elementwise => <,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user