mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 06:30:00 +00:00
Implement element-wise arithmetic operations for matrices with scalars and add corresponding unit tests
This commit is contained in:
parent
d6ccf88829
commit
d0b9ab0716
@ -297,6 +297,30 @@ impl_elementwise_op!(Sub, sub, -);
|
|||||||
impl_elementwise_op!(Mul, mul, *);
|
impl_elementwise_op!(Mul, mul, *);
|
||||||
impl_elementwise_op!(Div, div, /);
|
impl_elementwise_op!(Div, div, /);
|
||||||
|
|
||||||
|
/// Generates element-wise arithmetic implementations for matrices with scalars.
|
||||||
|
macro_rules! impl_elementwise_op_scalar {
|
||||||
|
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||||
|
impl<'a, T> std::ops::$OpTrait<T> for &'a Matrix<T>
|
||||||
|
where
|
||||||
|
T: Clone + std::ops::$OpTrait<Output = T>,
|
||||||
|
{
|
||||||
|
type Output = Matrix<T>;
|
||||||
|
|
||||||
|
fn $method(self, rhs: T) -> Matrix<T> {
|
||||||
|
// Apply the operation element-wise and collect into a new matrix
|
||||||
|
let data = self.data.iter().cloned().map(|a| a $op rhs.clone()).collect();
|
||||||
|
Matrix { rows: self.rows, cols: self.cols, data }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Instantiate element-wise addition, subtraction, multiplication, and division
|
||||||
|
impl_elementwise_op_scalar!(Add, add, +);
|
||||||
|
impl_elementwise_op_scalar!(Sub, sub, -);
|
||||||
|
impl_elementwise_op_scalar!(Mul, mul, *);
|
||||||
|
impl_elementwise_op_scalar!(Div, div, /);
|
||||||
|
|
||||||
/// Generates element-wise bitwise operations for boolean matrices.
|
/// Generates element-wise bitwise operations for boolean matrices.
|
||||||
macro_rules! impl_bitwise_op {
|
macro_rules! impl_bitwise_op {
|
||||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||||
@ -1347,4 +1371,25 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_arithmetic_broadcast() {
|
||||||
|
let matrix = create_test_matrix();
|
||||||
|
let result = &matrix + 1;
|
||||||
|
for i in 0..matrix.rows() {
|
||||||
|
for j in 0..matrix.cols() {
|
||||||
|
assert_eq!(result[(i, j)], matrix[(i, j)] + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// test mul and div
|
||||||
|
let result = &matrix * 2;
|
||||||
|
let result2 = &matrix / 2;
|
||||||
|
for i in 0..matrix.rows() {
|
||||||
|
for j in 0..matrix.cols() {
|
||||||
|
assert_eq!(result[(i, j)], matrix[(i, j)] * 2);
|
||||||
|
assert_eq!(result2[(i, j)], matrix[(i, j)] / 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user