mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Enhance element-wise arithmetic operations for Matrix to support multiple ownership variants and improve in-place modifications
This commit is contained in:
parent
a30a7101e8
commit
7f587a7b7e
@ -407,39 +407,73 @@ impl_elementwise_cmp! {
|
|||||||
ge_elem => >=,
|
ge_elem => >=,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates element-wise arithmetic implementations for matrices.
|
/// Generates element-wise arithmetic implementations for Matrix + Matrix
|
||||||
macro_rules! impl_elementwise_op {
|
macro_rules! impl_elementwise_op_matrix_all {
|
||||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||||
impl<'a, 'b, T> std::ops::$OpTrait<&'b Matrix<T>> for &'a Matrix<T>
|
// &Matrix + &Matrix
|
||||||
where
|
impl<'a, 'b, T> $OpTrait<&'b Matrix<T>> for &'a Matrix<T>
|
||||||
T: Clone + std::ops::$OpTrait<Output = T>,
|
where T: Clone + $OpTrait<Output = T> {
|
||||||
{
|
|
||||||
type Output = Matrix<T>;
|
type Output = Matrix<T>;
|
||||||
|
|
||||||
fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> {
|
fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> {
|
||||||
// Ensure both matrices have identical dimensions
|
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||||
assert_eq!(self.rows, rhs.rows, "row count mismatch");
|
let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
|
||||||
assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
|
||||||
// Apply the operation element-wise and collect into a new matrix
|
|
||||||
let data = self
|
|
||||||
.data
|
|
||||||
.iter()
|
|
||||||
.cloned()
|
|
||||||
.zip(rhs.data.iter().cloned())
|
|
||||||
.map(|(a, b)| a $op b)
|
|
||||||
.collect();
|
|
||||||
Matrix { rows: self.rows, cols: self.cols, data }
|
Matrix { rows: self.rows, cols: self.cols, data }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Matrix + &Matrix (Consumes self)
|
||||||
|
impl<'b, T> $OpTrait<&'b Matrix<T>> for Matrix<T>
|
||||||
|
where T: Clone + $OpTrait<Output = T> {
|
||||||
|
type Output = Matrix<T>;
|
||||||
|
fn $method(mut self, rhs: &'b Matrix<T>) -> Matrix<T> { // Make self mutable for potential in-place modification
|
||||||
|
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||||
|
// Modify data in place
|
||||||
|
for (a, b) in self.data.iter_mut().zip(rhs.data.iter().cloned()) {
|
||||||
|
*a = a.clone() $op b; // Requires T: Clone for the *a = part
|
||||||
|
}
|
||||||
|
// Return modified self (its data vec was consumed conceptually)
|
||||||
|
self
|
||||||
|
// Alternative: Collect into new Vec if T is not Clone or in-place is complex
|
||||||
|
// let data = self.data.into_iter().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
|
||||||
|
// Matrix { rows: self.rows, cols: self.cols, data }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// &Matrix + Matrix (Consumes rhs)
|
||||||
|
impl<'a, T> $OpTrait<Matrix<T>> for &'a Matrix<T>
|
||||||
|
where T: Clone + $OpTrait<Output = T> {
|
||||||
|
type Output = Matrix<T>;
|
||||||
|
fn $method(self, mut rhs: Matrix<T>) -> Matrix<T> { // Make rhs mutable
|
||||||
|
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||||
|
// Modify rhs data in place
|
||||||
|
for (a, b) in self.data.iter().cloned().zip(rhs.data.iter_mut()) {
|
||||||
|
*b = a $op b.clone(); // Requires T: Clone for the *b = part
|
||||||
|
}
|
||||||
|
// Return modified rhs
|
||||||
|
rhs
|
||||||
|
// Alternative: Collect into new Vec
|
||||||
|
// let data = self.data.iter().cloned().zip(rhs.data.into_iter()).map(|(a, b)| a $op b).collect();
|
||||||
|
// Matrix { rows: self.rows, cols: self.cols, data }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Matrix + Matrix (Consumes both)
|
||||||
|
impl<T> $OpTrait<Matrix<T>> for Matrix<T>
|
||||||
|
where T: Clone + $OpTrait<Output = T> {
|
||||||
|
type Output = Matrix<T>;
|
||||||
|
fn $method(mut self, rhs: Matrix<T>) -> Matrix<T> { // Make self mutable
|
||||||
|
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||||
|
// Modify self data in place
|
||||||
|
for (a, b) in self.data.iter_mut().zip(rhs.data.into_iter()) {
|
||||||
|
*a = a.clone() $op b; // Requires T: Clone
|
||||||
|
}
|
||||||
|
// Return modified self
|
||||||
|
self
|
||||||
|
// Alternative: Collect into new Vec
|
||||||
|
// let data = self.data.into_iter().zip(rhs.data.into_iter()).map(|(a, b)| a $op b).collect();
|
||||||
|
// Matrix { rows: self.rows, cols: self.cols, data }
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Instantiate element-wise addition, subtraction, multiplication, and division
|
|
||||||
impl_elementwise_op!(Add, add, +);
|
|
||||||
impl_elementwise_op!(Sub, sub, -);
|
|
||||||
impl_elementwise_op!(Mul, mul, *);
|
|
||||||
impl_elementwise_op!(Div, div, /);
|
|
||||||
|
|
||||||
/// Generates element-wise arithmetic implementations for matrices with scalars.
|
/// Generates element-wise arithmetic implementations for matrices with scalars.
|
||||||
macro_rules! impl_elementwise_op_scalar {
|
macro_rules! impl_elementwise_op_scalar {
|
||||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user