diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index 4fe80fe..3cb0484 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -474,29 +474,50 @@ macro_rules! impl_elementwise_op_matrix_all { }; } -/// Generates element-wise arithmetic implementations for matrices with scalars. -macro_rules! impl_elementwise_op_scalar { +/// Generates element-wise arithmetic implementations for Matrix + Scalar +macro_rules! impl_elementwise_op_scalar_all { ($OpTrait:ident, $method:ident, $op:tt) => { - impl<'a, T> std::ops::$OpTrait for &'a Matrix - where - T: Clone + std::ops::$OpTrait, - { + // &Matrix + Scalar + impl<'a, T> $OpTrait for &'a Matrix + where T: Clone + $OpTrait { type Output = Matrix; - fn $method(self, rhs: T) -> Matrix { - // Apply the operation element-wise and collect into a new matrix let data = self.data.iter().cloned().map(|a| a $op rhs.clone()).collect(); Matrix { rows: self.rows, cols: self.cols, data } } } + // Matrix + Scalar (Consumes self) + impl $OpTrait for Matrix + where T: Clone + $OpTrait { + type Output = Matrix; + fn $method(mut self, rhs: T) -> Matrix { // Make self mutable + // Modify self data in place + for a in self.data.iter_mut() { + *a = a.clone() $op rhs.clone(); // Requires T: Clone + } + // Return modified self + self + // Alternative: Collect into new Vec + // let data = self.data.into_iter().map(|a| a $op rhs.clone()).collect(); + // Matrix { rows: self.rows, cols: self.cols, data } + } + } + // NOTE: Scalar + Matrix (e.g., 1.0 + matrix) is NOT implemented here. + // It would require `impl Add> for T`, which is discouraged + // for primitive types unless inside the crate defining T. }; } -// Instantiate element-wise addition, subtraction, multiplication, and division -impl_elementwise_op_scalar!(Add, add, +); -impl_elementwise_op_scalar!(Sub, sub, -); -impl_elementwise_op_scalar!(Mul, mul, *); -impl_elementwise_op_scalar!(Div, div, /); +// Instantiate ALL combinations for arithmetic ops using the new macros +impl_elementwise_op_matrix_all!(Add, add, +); +impl_elementwise_op_matrix_all!(Sub, sub, -); +impl_elementwise_op_matrix_all!(Mul, mul, *); // Element-wise multiplication +impl_elementwise_op_matrix_all!(Div, div, /); // Element-wise division + +impl_elementwise_op_scalar_all!(Add, add, +); +impl_elementwise_op_scalar_all!(Sub, sub, -); +impl_elementwise_op_scalar_all!(Mul, mul, *); +impl_elementwise_op_scalar_all!(Div, div, /); /// Generates element-wise bitwise operations for boolean matrices. macro_rules! impl_bitwise_op {