From 7f587a7b7e2b4bccf3c027fc1043d450ff57bcb2 Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Sat, 3 May 2025 00:37:06 +0100 Subject: [PATCH] Enhance element-wise arithmetic operations for Matrix to support multiple ownership variants and improve in-place modifications --- src/matrix/mat.rs | 82 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 24 deletions(-) diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index 110d2d6..4fe80fe 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -407,39 +407,73 @@ impl_elementwise_cmp! { ge_elem => >=, } -/// Generates element-wise arithmetic implementations for matrices. -macro_rules! impl_elementwise_op { +/// Generates element-wise arithmetic implementations for Matrix + Matrix +macro_rules! impl_elementwise_op_matrix_all { ($OpTrait:ident, $method:ident, $op:tt) => { - impl<'a, 'b, T> std::ops::$OpTrait<&'b Matrix> for &'a Matrix - where - T: Clone + std::ops::$OpTrait, - { + // &Matrix + &Matrix + impl<'a, 'b, T> $OpTrait<&'b Matrix> for &'a Matrix + where T: Clone + $OpTrait { type Output = Matrix; - fn $method(self, rhs: &'b Matrix) -> Matrix { - // Ensure both matrices have identical dimensions - assert_eq!(self.rows, rhs.rows, "row count mismatch"); - assert_eq!(self.cols, rhs.cols, "col count mismatch"); - // Apply the operation element-wise and collect into a new matrix - let data = self - .data - .iter() - .cloned() - .zip(rhs.data.iter().cloned()) - .map(|(a, b)| a $op b) - .collect(); + assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch"); + let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect(); Matrix { rows: self.rows, cols: self.cols, data } } } + // Matrix + &Matrix (Consumes self) + impl<'b, T> $OpTrait<&'b Matrix> for Matrix + where T: Clone + $OpTrait { + type Output = Matrix; + fn $method(mut self, rhs: &'b Matrix) -> Matrix { // Make self mutable for potential in-place modification + assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch"); + // Modify data in place + for (a, b) in self.data.iter_mut().zip(rhs.data.iter().cloned()) { + *a = a.clone() $op b; // Requires T: Clone for the *a = part + } + // Return modified self (its data vec was consumed conceptually) + self + // Alternative: Collect into new Vec if T is not Clone or in-place is complex + // let data = self.data.into_iter().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect(); + // Matrix { rows: self.rows, cols: self.cols, data } + } + } + // &Matrix + Matrix (Consumes rhs) + impl<'a, T> $OpTrait> for &'a Matrix + where T: Clone + $OpTrait { + type Output = Matrix; + fn $method(self, mut rhs: Matrix) -> Matrix { // Make rhs mutable + assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch"); + // Modify rhs data in place + for (a, b) in self.data.iter().cloned().zip(rhs.data.iter_mut()) { + *b = a $op b.clone(); // Requires T: Clone for the *b = part + } + // Return modified rhs + rhs + // Alternative: Collect into new Vec + // let data = self.data.iter().cloned().zip(rhs.data.into_iter()).map(|(a, b)| a $op b).collect(); + // Matrix { rows: self.rows, cols: self.cols, data } + } + } + // Matrix + Matrix (Consumes both) + impl $OpTrait> for Matrix + where T: Clone + $OpTrait { + type Output = Matrix; + fn $method(mut self, rhs: Matrix) -> Matrix { // Make self mutable + assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch"); + // Modify self data in place + for (a, b) in self.data.iter_mut().zip(rhs.data.into_iter()) { + *a = a.clone() $op b; // Requires T: Clone + } + // Return modified self + self + // Alternative: Collect into new Vec + // let data = self.data.into_iter().zip(rhs.data.into_iter()).map(|(a, b)| a $op b).collect(); + // Matrix { rows: self.rows, cols: self.cols, data } + } + } }; } -// Instantiate element-wise addition, subtraction, multiplication, and division -impl_elementwise_op!(Add, add, +); -impl_elementwise_op!(Sub, sub, -); -impl_elementwise_op!(Mul, mul, *); -impl_elementwise_op!(Div, div, /); - /// Generates element-wise arithmetic implementations for matrices with scalars. macro_rules! impl_elementwise_op_scalar { ($OpTrait:ident, $method:ident, $op:tt) => {