Enhance element-wise arithmetic operations for Matrix to support multiple ownership variants and improve in-place modifications

This commit is contained in:
Palash Tyagi 2025-05-03 00:37:06 +01:00
parent a30a7101e8
commit 7f587a7b7e

View File

@ -407,39 +407,73 @@ impl_elementwise_cmp! {
ge_elem => >=, ge_elem => >=,
} }
/// Generates element-wise arithmetic implementations for matrices. /// Generates element-wise arithmetic implementations for Matrix + Matrix
macro_rules! impl_elementwise_op { macro_rules! impl_elementwise_op_matrix_all {
($OpTrait:ident, $method:ident, $op:tt) => { ($OpTrait:ident, $method:ident, $op:tt) => {
impl<'a, 'b, T> std::ops::$OpTrait<&'b Matrix<T>> for &'a Matrix<T> // &Matrix + &Matrix
where impl<'a, 'b, T> $OpTrait<&'b Matrix<T>> for &'a Matrix<T>
T: Clone + std::ops::$OpTrait<Output = T>, where T: Clone + $OpTrait<Output = T> {
{
type Output = Matrix<T>; type Output = Matrix<T>;
fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> { fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> {
// Ensure both matrices have identical dimensions assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
assert_eq!(self.rows, rhs.rows, "row count mismatch"); let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
assert_eq!(self.cols, rhs.cols, "col count mismatch");
// Apply the operation element-wise and collect into a new matrix
let data = self
.data
.iter()
.cloned()
.zip(rhs.data.iter().cloned())
.map(|(a, b)| a $op b)
.collect();
Matrix { rows: self.rows, cols: self.cols, data } Matrix { rows: self.rows, cols: self.cols, data }
} }
} }
// Matrix + &Matrix (Consumes self)
impl<'b, T> $OpTrait<&'b Matrix<T>> for Matrix<T>
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(mut self, rhs: &'b Matrix<T>) -> Matrix<T> { // Make self mutable for potential in-place modification
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
// Modify data in place
for (a, b) in self.data.iter_mut().zip(rhs.data.iter().cloned()) {
*a = a.clone() $op b; // Requires T: Clone for the *a = part
}
// Return modified self (its data vec was consumed conceptually)
self
// Alternative: Collect into new Vec if T is not Clone or in-place is complex
// let data = self.data.into_iter().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
// Matrix { rows: self.rows, cols: self.cols, data }
}
}
// &Matrix + Matrix (Consumes rhs)
impl<'a, T> $OpTrait<Matrix<T>> for &'a Matrix<T>
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(self, mut rhs: Matrix<T>) -> Matrix<T> { // Make rhs mutable
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
// Modify rhs data in place
for (a, b) in self.data.iter().cloned().zip(rhs.data.iter_mut()) {
*b = a $op b.clone(); // Requires T: Clone for the *b = part
}
// Return modified rhs
rhs
// Alternative: Collect into new Vec
// let data = self.data.iter().cloned().zip(rhs.data.into_iter()).map(|(a, b)| a $op b).collect();
// Matrix { rows: self.rows, cols: self.cols, data }
}
}
// Matrix + Matrix (Consumes both)
impl<T> $OpTrait<Matrix<T>> for Matrix<T>
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(mut self, rhs: Matrix<T>) -> Matrix<T> { // Make self mutable
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
// Modify self data in place
for (a, b) in self.data.iter_mut().zip(rhs.data.into_iter()) {
*a = a.clone() $op b; // Requires T: Clone
}
// Return modified self
self
// Alternative: Collect into new Vec
// let data = self.data.into_iter().zip(rhs.data.into_iter()).map(|(a, b)| a $op b).collect();
// Matrix { rows: self.rows, cols: self.cols, data }
}
}
}; };
} }
// Instantiate element-wise addition, subtraction, multiplication, and division
impl_elementwise_op!(Add, add, +);
impl_elementwise_op!(Sub, sub, -);
impl_elementwise_op!(Mul, mul, *);
impl_elementwise_op!(Div, div, /);
/// Generates element-wise arithmetic implementations for matrices with scalars. /// Generates element-wise arithmetic implementations for matrices with scalars.
macro_rules! impl_elementwise_op_scalar { macro_rules! impl_elementwise_op_scalar {
($OpTrait:ident, $method:ident, $op:tt) => { ($OpTrait:ident, $method:ident, $op:tt) => {