From d0b9ab0716c00f34aef6adbcd8a1e6b6fe4535fb Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Sat, 26 Apr 2025 02:49:43 +0100 Subject: [PATCH] Implement element-wise arithmetic operations for matrices with scalars and add corresponding unit tests --- src/matrix/mat.rs | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index 95f0ac7..9f0cb09 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -297,6 +297,30 @@ impl_elementwise_op!(Sub, sub, -); impl_elementwise_op!(Mul, mul, *); impl_elementwise_op!(Div, div, /); +/// Generates element-wise arithmetic implementations for matrices with scalars. +macro_rules! impl_elementwise_op_scalar { + ($OpTrait:ident, $method:ident, $op:tt) => { + impl<'a, T> std::ops::$OpTrait for &'a Matrix + where + T: Clone + std::ops::$OpTrait, + { + type Output = Matrix; + + fn $method(self, rhs: T) -> Matrix { + // Apply the operation element-wise and collect into a new matrix + let data = self.data.iter().cloned().map(|a| a $op rhs.clone()).collect(); + Matrix { rows: self.rows, cols: self.cols, data } + } + } + }; +} + +// Instantiate element-wise addition, subtraction, multiplication, and division +impl_elementwise_op_scalar!(Add, add, +); +impl_elementwise_op_scalar!(Sub, sub, -); +impl_elementwise_op_scalar!(Mul, mul, *); +impl_elementwise_op_scalar!(Div, div, /); + /// Generates element-wise bitwise operations for boolean matrices. macro_rules! impl_bitwise_op { ($OpTrait:ident, $method:ident, $op:tt) => { @@ -1347,4 +1371,25 @@ mod tests { } } } + + #[test] + fn test_arithmetic_broadcast() { + let matrix = create_test_matrix(); + let result = &matrix + 1; + for i in 0..matrix.rows() { + for j in 0..matrix.cols() { + assert_eq!(result[(i, j)], matrix[(i, j)] + 1); + } + } + + // test mul and div + let result = &matrix * 2; + let result2 = &matrix / 2; + for i in 0..matrix.rows() { + for j in 0..matrix.cols() { + assert_eq!(result[(i, j)], matrix[(i, j)] * 2); + assert_eq!(result2[(i, j)], matrix[(i, j)] / 2); + } + } + } }