mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Implement element-wise arithmetic operations for matrices with scalars and add corresponding unit tests
This commit is contained in:
parent
d6ccf88829
commit
d0b9ab0716
@ -297,6 +297,30 @@ impl_elementwise_op!(Sub, sub, -);
|
||||
impl_elementwise_op!(Mul, mul, *);
|
||||
impl_elementwise_op!(Div, div, /);
|
||||
|
||||
/// Generates element-wise arithmetic implementations for matrices with scalars.
|
||||
macro_rules! impl_elementwise_op_scalar {
|
||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||
impl<'a, T> std::ops::$OpTrait<T> for &'a Matrix<T>
|
||||
where
|
||||
T: Clone + std::ops::$OpTrait<Output = T>,
|
||||
{
|
||||
type Output = Matrix<T>;
|
||||
|
||||
fn $method(self, rhs: T) -> Matrix<T> {
|
||||
// Apply the operation element-wise and collect into a new matrix
|
||||
let data = self.data.iter().cloned().map(|a| a $op rhs.clone()).collect();
|
||||
Matrix { rows: self.rows, cols: self.cols, data }
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Instantiate element-wise addition, subtraction, multiplication, and division
|
||||
impl_elementwise_op_scalar!(Add, add, +);
|
||||
impl_elementwise_op_scalar!(Sub, sub, -);
|
||||
impl_elementwise_op_scalar!(Mul, mul, *);
|
||||
impl_elementwise_op_scalar!(Div, div, /);
|
||||
|
||||
/// Generates element-wise bitwise operations for boolean matrices.
|
||||
macro_rules! impl_bitwise_op {
|
||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||
@ -1347,4 +1371,25 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_arithmetic_broadcast() {
|
||||
let matrix = create_test_matrix();
|
||||
let result = &matrix + 1;
|
||||
for i in 0..matrix.rows() {
|
||||
for j in 0..matrix.cols() {
|
||||
assert_eq!(result[(i, j)], matrix[(i, j)] + 1);
|
||||
}
|
||||
}
|
||||
|
||||
// test mul and div
|
||||
let result = &matrix * 2;
|
||||
let result2 = &matrix / 2;
|
||||
for i in 0..matrix.rows() {
|
||||
for j in 0..matrix.cols() {
|
||||
assert_eq!(result[(i, j)], matrix[(i, j)] * 2);
|
||||
assert_eq!(result2[(i, j)], matrix[(i, j)] / 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user