mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:19:59 +00:00
Refactor element-wise arithmetic operations to use a centralized dimension check for improved error handling
This commit is contained in:
parent
ca734fbedf
commit
3cb68be062
@ -407,6 +407,23 @@ impl_elementwise_cmp! {
|
||||
ge_elem => >=,
|
||||
}
|
||||
|
||||
// --- Element-wise Arithmetic Operations (Macros generating all ownership variants) ---
|
||||
|
||||
fn check_matrix_dims_for_ops<T>(lhs: &Matrix<T>, rhs: &Matrix<T>) {
|
||||
assert!(
|
||||
lhs.rows == rhs.rows,
|
||||
"Row count mismatch: left has {} rows, right has {} rows",
|
||||
lhs.rows,
|
||||
rhs.rows
|
||||
);
|
||||
assert!(
|
||||
lhs.cols == rhs.cols,
|
||||
"Column count mismatch: left has {} columns, right has {} columns",
|
||||
lhs.cols,
|
||||
rhs.cols
|
||||
);
|
||||
}
|
||||
|
||||
/// Generates element-wise arithmetic implementations for Matrix + Matrix
|
||||
macro_rules! impl_elementwise_op_matrix_all {
|
||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||
@ -415,7 +432,7 @@ macro_rules! impl_elementwise_op_matrix_all {
|
||||
where T: Clone + $OpTrait<Output = T> {
|
||||
type Output = Matrix<T>;
|
||||
fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> {
|
||||
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||
check_matrix_dims_for_ops(self, rhs);
|
||||
let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
|
||||
Matrix { rows: self.rows, cols: self.cols, data }
|
||||
}
|
||||
@ -425,7 +442,7 @@ macro_rules! impl_elementwise_op_matrix_all {
|
||||
where T: Clone + $OpTrait<Output = T> {
|
||||
type Output = Matrix<T>;
|
||||
fn $method(mut self, rhs: &'b Matrix<T>) -> Matrix<T> { // Make self mutable for potential in-place modification
|
||||
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||
check_matrix_dims_for_ops(&self, rhs);
|
||||
// Modify data in place
|
||||
for (a, b) in self.data.iter_mut().zip(rhs.data.iter().cloned()) {
|
||||
*a = a.clone() $op b; // Requires T: Clone for the *a = part
|
||||
@ -442,7 +459,7 @@ macro_rules! impl_elementwise_op_matrix_all {
|
||||
where T: Clone + $OpTrait<Output = T> {
|
||||
type Output = Matrix<T>;
|
||||
fn $method(self, mut rhs: Matrix<T>) -> Matrix<T> { // Make rhs mutable
|
||||
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||
check_matrix_dims_for_ops(self, &rhs);
|
||||
// Modify rhs data in place
|
||||
for (a, b) in self.data.iter().cloned().zip(rhs.data.iter_mut()) {
|
||||
*b = a $op b.clone(); // Requires T: Clone for the *b = part
|
||||
@ -459,7 +476,7 @@ macro_rules! impl_elementwise_op_matrix_all {
|
||||
where T: Clone + $OpTrait<Output = T> {
|
||||
type Output = Matrix<T>;
|
||||
fn $method(mut self, rhs: Matrix<T>) -> Matrix<T> { // Make self mutable
|
||||
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||
check_matrix_dims_for_ops(&self, &rhs);
|
||||
// Modify self data in place
|
||||
for (a, b) in self.data.iter_mut().zip(rhs.data.into_iter()) {
|
||||
*a = a.clone() $op b; // Requires T: Clone
|
||||
@ -519,14 +536,15 @@ impl_elementwise_op_scalar_all!(Sub, sub, -);
|
||||
impl_elementwise_op_scalar_all!(Mul, mul, *);
|
||||
impl_elementwise_op_scalar_all!(Div, div, /);
|
||||
|
||||
/// Generates element-wise bitwise operations for boolean matrices.
|
||||
// --- Element-wise Bitwise Operations (BoolMatrix) ---
|
||||
|
||||
macro_rules! impl_bitwise_op_all {
|
||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||
// &Matrix<bool> OP &Matrix<bool>
|
||||
impl<'a, 'b> $OpTrait<&'b Matrix<bool>> for &'a Matrix<bool> {
|
||||
type Output = Matrix<bool>;
|
||||
fn $method(self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
|
||||
assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols);
|
||||
check_matrix_dims_for_ops(self, rhs);
|
||||
let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
|
||||
Matrix { rows: self.rows, cols: self.cols, data }
|
||||
}
|
||||
@ -535,17 +553,17 @@ macro_rules! impl_bitwise_op_all {
|
||||
impl<'b> $OpTrait<&'b Matrix<bool>> for Matrix<bool> {
|
||||
type Output = Matrix<bool>;
|
||||
fn $method(mut self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
|
||||
assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols);
|
||||
check_matrix_dims_for_ops(&self, rhs);
|
||||
for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) { *a = *a $op *b; } // bool is Copy
|
||||
self
|
||||
}
|
||||
}
|
||||
// &Matrix<bool> OP Matrix<bool>
|
||||
impl<'a> $OpTrait<Matrix<bool>> for &'a Matrix<bool> {
|
||||
impl<'a> $OpTrait<Matrix<bool>> for &'a Matrix<bool> {
|
||||
type Output = Matrix<bool>;
|
||||
fn $method(self, mut rhs: Matrix<bool>) -> Matrix<bool> {
|
||||
assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols);
|
||||
for (a, b) in self.data.iter().zip(rhs.data.iter_mut()) { *b = *a $op *b; } // bool is Copy
|
||||
check_matrix_dims_for_ops(self, &rhs);
|
||||
for (a, b) in self.data.iter().zip(rhs.data.iter_mut()) { *b = *a $op *b; } // bool is Copy
|
||||
rhs
|
||||
}
|
||||
}
|
||||
@ -553,7 +571,7 @@ macro_rules! impl_bitwise_op_all {
|
||||
impl $OpTrait<Matrix<bool>> for Matrix<bool> {
|
||||
type Output = Matrix<bool>;
|
||||
fn $method(mut self, rhs: Matrix<bool>) -> Matrix<bool> {
|
||||
assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols);
|
||||
check_matrix_dims_for_ops(&self, &rhs);
|
||||
for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) { *a = *a $op *b; } // bool is Copy
|
||||
self
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user