mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Enhance element-wise arithmetic operations for Matrix to support multiple ownership variants and improve in-place modifications
This commit is contained in:
parent
a30a7101e8
commit
7f587a7b7e
@ -407,39 +407,73 @@ impl_elementwise_cmp! {
|
||||
ge_elem => >=,
|
||||
}
|
||||
|
||||
/// Generates element-wise arithmetic implementations for matrices.
|
||||
macro_rules! impl_elementwise_op {
|
||||
/// Generates element-wise arithmetic implementations for Matrix + Matrix
|
||||
macro_rules! impl_elementwise_op_matrix_all {
|
||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||
impl<'a, 'b, T> std::ops::$OpTrait<&'b Matrix<T>> for &'a Matrix<T>
|
||||
where
|
||||
T: Clone + std::ops::$OpTrait<Output = T>,
|
||||
{
|
||||
// &Matrix + &Matrix
|
||||
impl<'a, 'b, T> $OpTrait<&'b Matrix<T>> for &'a Matrix<T>
|
||||
where T: Clone + $OpTrait<Output = T> {
|
||||
type Output = Matrix<T>;
|
||||
|
||||
fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> {
|
||||
// Ensure both matrices have identical dimensions
|
||||
assert_eq!(self.rows, rhs.rows, "row count mismatch");
|
||||
assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||
// Apply the operation element-wise and collect into a new matrix
|
||||
let data = self
|
||||
.data
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(rhs.data.iter().cloned())
|
||||
.map(|(a, b)| a $op b)
|
||||
.collect();
|
||||
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||
let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
|
||||
Matrix { rows: self.rows, cols: self.cols, data }
|
||||
}
|
||||
}
|
||||
// Matrix + &Matrix (Consumes self)
|
||||
impl<'b, T> $OpTrait<&'b Matrix<T>> for Matrix<T>
|
||||
where T: Clone + $OpTrait<Output = T> {
|
||||
type Output = Matrix<T>;
|
||||
fn $method(mut self, rhs: &'b Matrix<T>) -> Matrix<T> { // Make self mutable for potential in-place modification
|
||||
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||
// Modify data in place
|
||||
for (a, b) in self.data.iter_mut().zip(rhs.data.iter().cloned()) {
|
||||
*a = a.clone() $op b; // Requires T: Clone for the *a = part
|
||||
}
|
||||
// Return modified self (its data vec was consumed conceptually)
|
||||
self
|
||||
// Alternative: Collect into new Vec if T is not Clone or in-place is complex
|
||||
// let data = self.data.into_iter().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
|
||||
// Matrix { rows: self.rows, cols: self.cols, data }
|
||||
}
|
||||
}
|
||||
// &Matrix + Matrix (Consumes rhs)
|
||||
impl<'a, T> $OpTrait<Matrix<T>> for &'a Matrix<T>
|
||||
where T: Clone + $OpTrait<Output = T> {
|
||||
type Output = Matrix<T>;
|
||||
fn $method(self, mut rhs: Matrix<T>) -> Matrix<T> { // Make rhs mutable
|
||||
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||
// Modify rhs data in place
|
||||
for (a, b) in self.data.iter().cloned().zip(rhs.data.iter_mut()) {
|
||||
*b = a $op b.clone(); // Requires T: Clone for the *b = part
|
||||
}
|
||||
// Return modified rhs
|
||||
rhs
|
||||
// Alternative: Collect into new Vec
|
||||
// let data = self.data.iter().cloned().zip(rhs.data.into_iter()).map(|(a, b)| a $op b).collect();
|
||||
// Matrix { rows: self.rows, cols: self.cols, data }
|
||||
}
|
||||
}
|
||||
// Matrix + Matrix (Consumes both)
|
||||
impl<T> $OpTrait<Matrix<T>> for Matrix<T>
|
||||
where T: Clone + $OpTrait<Output = T> {
|
||||
type Output = Matrix<T>;
|
||||
fn $method(mut self, rhs: Matrix<T>) -> Matrix<T> { // Make self mutable
|
||||
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||
// Modify self data in place
|
||||
for (a, b) in self.data.iter_mut().zip(rhs.data.into_iter()) {
|
||||
*a = a.clone() $op b; // Requires T: Clone
|
||||
}
|
||||
// Return modified self
|
||||
self
|
||||
// Alternative: Collect into new Vec
|
||||
// let data = self.data.into_iter().zip(rhs.data.into_iter()).map(|(a, b)| a $op b).collect();
|
||||
// Matrix { rows: self.rows, cols: self.cols, data }
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Instantiate element-wise addition, subtraction, multiplication, and division
|
||||
impl_elementwise_op!(Add, add, +);
|
||||
impl_elementwise_op!(Sub, sub, -);
|
||||
impl_elementwise_op!(Mul, mul, *);
|
||||
impl_elementwise_op!(Div, div, /);
|
||||
|
||||
/// Generates element-wise arithmetic implementations for matrices with scalars.
|
||||
macro_rules! impl_elementwise_op_scalar {
|
||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||
|
Loading…
x
Reference in New Issue
Block a user