Enhance element-wise arithmetic operations for Matrix + Scalar to support in-place modifications and improve performance

This commit is contained in:
Palash Tyagi 2025-05-03 00:38:07 +01:00
parent 7f587a7b7e
commit 36a0846efa

View File

@ -474,29 +474,50 @@ macro_rules! impl_elementwise_op_matrix_all {
};
}
/// Generates element-wise arithmetic implementations for matrices with scalars.
macro_rules! impl_elementwise_op_scalar {
/// Generates element-wise arithmetic implementations for Matrix + Scalar
macro_rules! impl_elementwise_op_scalar_all {
($OpTrait:ident, $method:ident, $op:tt) => {
impl<'a, T> std::ops::$OpTrait<T> for &'a Matrix<T>
where
T: Clone + std::ops::$OpTrait<Output = T>,
{
// &Matrix + Scalar
impl<'a, T> $OpTrait<T> for &'a Matrix<T>
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(self, rhs: T) -> Matrix<T> {
// Apply the operation element-wise and collect into a new matrix
let data = self.data.iter().cloned().map(|a| a $op rhs.clone()).collect();
Matrix { rows: self.rows, cols: self.cols, data }
}
}
// Matrix + Scalar (Consumes self)
impl<T> $OpTrait<T> for Matrix<T>
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(mut self, rhs: T) -> Matrix<T> { // Make self mutable
// Modify self data in place
for a in self.data.iter_mut() {
*a = a.clone() $op rhs.clone(); // Requires T: Clone
}
// Return modified self
self
// Alternative: Collect into new Vec
// let data = self.data.into_iter().map(|a| a $op rhs.clone()).collect();
// Matrix { rows: self.rows, cols: self.cols, data }
}
}
// NOTE: Scalar + Matrix (e.g., 1.0 + matrix) is NOT implemented here.
// It would require `impl Add<Matrix<T>> for T`, which is discouraged
// for primitive types unless inside the crate defining T.
};
}
// Instantiate element-wise addition, subtraction, multiplication, and division
impl_elementwise_op_scalar!(Add, add, +);
impl_elementwise_op_scalar!(Sub, sub, -);
impl_elementwise_op_scalar!(Mul, mul, *);
impl_elementwise_op_scalar!(Div, div, /);
// Instantiate ALL combinations for arithmetic ops using the new macros
impl_elementwise_op_matrix_all!(Add, add, +);
impl_elementwise_op_matrix_all!(Sub, sub, -);
impl_elementwise_op_matrix_all!(Mul, mul, *); // Element-wise multiplication
impl_elementwise_op_matrix_all!(Div, div, /); // Element-wise division
impl_elementwise_op_scalar_all!(Add, add, +);
impl_elementwise_op_scalar_all!(Sub, sub, -);
impl_elementwise_op_scalar_all!(Mul, mul, *);
impl_elementwise_op_scalar_all!(Div, div, /);
/// Generates element-wise bitwise operations for boolean matrices.
macro_rules! impl_bitwise_op {