Implement element-wise arithmetic operations for matrices with scalars and add corresponding unit tests

This commit is contained in:
Palash Tyagi 2025-04-26 02:49:43 +01:00
parent d6ccf88829
commit d0b9ab0716

View File

@ -297,6 +297,30 @@ impl_elementwise_op!(Sub, sub, -);
impl_elementwise_op!(Mul, mul, *);
impl_elementwise_op!(Div, div, /);
/// Generates element-wise arithmetic implementations for matrices with scalars.
macro_rules! impl_elementwise_op_scalar {
($OpTrait:ident, $method:ident, $op:tt) => {
impl<'a, T> std::ops::$OpTrait<T> for &'a Matrix<T>
where
T: Clone + std::ops::$OpTrait<Output = T>,
{
type Output = Matrix<T>;
fn $method(self, rhs: T) -> Matrix<T> {
// Apply the operation element-wise and collect into a new matrix
let data = self.data.iter().cloned().map(|a| a $op rhs.clone()).collect();
Matrix { rows: self.rows, cols: self.cols, data }
}
}
};
}
// Instantiate element-wise addition, subtraction, multiplication, and division
impl_elementwise_op_scalar!(Add, add, +);
impl_elementwise_op_scalar!(Sub, sub, -);
impl_elementwise_op_scalar!(Mul, mul, *);
impl_elementwise_op_scalar!(Div, div, /);
/// Generates element-wise bitwise operations for boolean matrices.
macro_rules! impl_bitwise_op {
($OpTrait:ident, $method:ident, $op:tt) => {
@ -1347,4 +1371,25 @@ mod tests {
}
}
}
#[test]
fn test_arithmetic_broadcast() {
let matrix = create_test_matrix();
let result = &matrix + 1;
for i in 0..matrix.rows() {
for j in 0..matrix.cols() {
assert_eq!(result[(i, j)], matrix[(i, j)] + 1);
}
}
// test mul and div
let result = &matrix * 2;
let result2 = &matrix / 2;
for i in 0..matrix.rows() {
for j in 0..matrix.cols() {
assert_eq!(result[(i, j)], matrix[(i, j)] * 2);
assert_eq!(result2[(i, j)], matrix[(i, j)] / 2);
}
}
}
}