mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 08:00:00 +00:00
Refactor element-wise arithmetic operations to use a centralized dimension check for improved error handling
This commit is contained in:
parent
ca734fbedf
commit
3cb68be062
@ -407,6 +407,23 @@ impl_elementwise_cmp! {
|
|||||||
ge_elem => >=,
|
ge_elem => >=,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Element-wise Arithmetic Operations (Macros generating all ownership variants) ---
|
||||||
|
|
||||||
|
fn check_matrix_dims_for_ops<T>(lhs: &Matrix<T>, rhs: &Matrix<T>) {
|
||||||
|
assert!(
|
||||||
|
lhs.rows == rhs.rows,
|
||||||
|
"Row count mismatch: left has {} rows, right has {} rows",
|
||||||
|
lhs.rows,
|
||||||
|
rhs.rows
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
lhs.cols == rhs.cols,
|
||||||
|
"Column count mismatch: left has {} columns, right has {} columns",
|
||||||
|
lhs.cols,
|
||||||
|
rhs.cols
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
/// Generates element-wise arithmetic implementations for Matrix + Matrix
|
/// Generates element-wise arithmetic implementations for Matrix + Matrix
|
||||||
macro_rules! impl_elementwise_op_matrix_all {
|
macro_rules! impl_elementwise_op_matrix_all {
|
||||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||||
@ -415,7 +432,7 @@ macro_rules! impl_elementwise_op_matrix_all {
|
|||||||
where T: Clone + $OpTrait<Output = T> {
|
where T: Clone + $OpTrait<Output = T> {
|
||||||
type Output = Matrix<T>;
|
type Output = Matrix<T>;
|
||||||
fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> {
|
fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> {
|
||||||
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
check_matrix_dims_for_ops(self, rhs);
|
||||||
let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
|
let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
|
||||||
Matrix { rows: self.rows, cols: self.cols, data }
|
Matrix { rows: self.rows, cols: self.cols, data }
|
||||||
}
|
}
|
||||||
@ -425,7 +442,7 @@ macro_rules! impl_elementwise_op_matrix_all {
|
|||||||
where T: Clone + $OpTrait<Output = T> {
|
where T: Clone + $OpTrait<Output = T> {
|
||||||
type Output = Matrix<T>;
|
type Output = Matrix<T>;
|
||||||
fn $method(mut self, rhs: &'b Matrix<T>) -> Matrix<T> { // Make self mutable for potential in-place modification
|
fn $method(mut self, rhs: &'b Matrix<T>) -> Matrix<T> { // Make self mutable for potential in-place modification
|
||||||
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
check_matrix_dims_for_ops(&self, rhs);
|
||||||
// Modify data in place
|
// Modify data in place
|
||||||
for (a, b) in self.data.iter_mut().zip(rhs.data.iter().cloned()) {
|
for (a, b) in self.data.iter_mut().zip(rhs.data.iter().cloned()) {
|
||||||
*a = a.clone() $op b; // Requires T: Clone for the *a = part
|
*a = a.clone() $op b; // Requires T: Clone for the *a = part
|
||||||
@ -442,7 +459,7 @@ macro_rules! impl_elementwise_op_matrix_all {
|
|||||||
where T: Clone + $OpTrait<Output = T> {
|
where T: Clone + $OpTrait<Output = T> {
|
||||||
type Output = Matrix<T>;
|
type Output = Matrix<T>;
|
||||||
fn $method(self, mut rhs: Matrix<T>) -> Matrix<T> { // Make rhs mutable
|
fn $method(self, mut rhs: Matrix<T>) -> Matrix<T> { // Make rhs mutable
|
||||||
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
check_matrix_dims_for_ops(self, &rhs);
|
||||||
// Modify rhs data in place
|
// Modify rhs data in place
|
||||||
for (a, b) in self.data.iter().cloned().zip(rhs.data.iter_mut()) {
|
for (a, b) in self.data.iter().cloned().zip(rhs.data.iter_mut()) {
|
||||||
*b = a $op b.clone(); // Requires T: Clone for the *b = part
|
*b = a $op b.clone(); // Requires T: Clone for the *b = part
|
||||||
@ -459,7 +476,7 @@ macro_rules! impl_elementwise_op_matrix_all {
|
|||||||
where T: Clone + $OpTrait<Output = T> {
|
where T: Clone + $OpTrait<Output = T> {
|
||||||
type Output = Matrix<T>;
|
type Output = Matrix<T>;
|
||||||
fn $method(mut self, rhs: Matrix<T>) -> Matrix<T> { // Make self mutable
|
fn $method(mut self, rhs: Matrix<T>) -> Matrix<T> { // Make self mutable
|
||||||
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
check_matrix_dims_for_ops(&self, &rhs);
|
||||||
// Modify self data in place
|
// Modify self data in place
|
||||||
for (a, b) in self.data.iter_mut().zip(rhs.data.into_iter()) {
|
for (a, b) in self.data.iter_mut().zip(rhs.data.into_iter()) {
|
||||||
*a = a.clone() $op b; // Requires T: Clone
|
*a = a.clone() $op b; // Requires T: Clone
|
||||||
@ -519,14 +536,15 @@ impl_elementwise_op_scalar_all!(Sub, sub, -);
|
|||||||
impl_elementwise_op_scalar_all!(Mul, mul, *);
|
impl_elementwise_op_scalar_all!(Mul, mul, *);
|
||||||
impl_elementwise_op_scalar_all!(Div, div, /);
|
impl_elementwise_op_scalar_all!(Div, div, /);
|
||||||
|
|
||||||
/// Generates element-wise bitwise operations for boolean matrices.
|
// --- Element-wise Bitwise Operations (BoolMatrix) ---
|
||||||
|
|
||||||
macro_rules! impl_bitwise_op_all {
|
macro_rules! impl_bitwise_op_all {
|
||||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||||
// &Matrix<bool> OP &Matrix<bool>
|
// &Matrix<bool> OP &Matrix<bool>
|
||||||
impl<'a, 'b> $OpTrait<&'b Matrix<bool>> for &'a Matrix<bool> {
|
impl<'a, 'b> $OpTrait<&'b Matrix<bool>> for &'a Matrix<bool> {
|
||||||
type Output = Matrix<bool>;
|
type Output = Matrix<bool>;
|
||||||
fn $method(self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
|
fn $method(self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
|
||||||
assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols);
|
check_matrix_dims_for_ops(self, rhs);
|
||||||
let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
|
let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
|
||||||
Matrix { rows: self.rows, cols: self.cols, data }
|
Matrix { rows: self.rows, cols: self.cols, data }
|
||||||
}
|
}
|
||||||
@ -535,17 +553,17 @@ macro_rules! impl_bitwise_op_all {
|
|||||||
impl<'b> $OpTrait<&'b Matrix<bool>> for Matrix<bool> {
|
impl<'b> $OpTrait<&'b Matrix<bool>> for Matrix<bool> {
|
||||||
type Output = Matrix<bool>;
|
type Output = Matrix<bool>;
|
||||||
fn $method(mut self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
|
fn $method(mut self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
|
||||||
assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols);
|
check_matrix_dims_for_ops(&self, rhs);
|
||||||
for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) { *a = *a $op *b; } // bool is Copy
|
for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) { *a = *a $op *b; } // bool is Copy
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// &Matrix<bool> OP Matrix<bool>
|
// &Matrix<bool> OP Matrix<bool>
|
||||||
impl<'a> $OpTrait<Matrix<bool>> for &'a Matrix<bool> {
|
impl<'a> $OpTrait<Matrix<bool>> for &'a Matrix<bool> {
|
||||||
type Output = Matrix<bool>;
|
type Output = Matrix<bool>;
|
||||||
fn $method(self, mut rhs: Matrix<bool>) -> Matrix<bool> {
|
fn $method(self, mut rhs: Matrix<bool>) -> Matrix<bool> {
|
||||||
assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols);
|
check_matrix_dims_for_ops(self, &rhs);
|
||||||
for (a, b) in self.data.iter().zip(rhs.data.iter_mut()) { *b = *a $op *b; } // bool is Copy
|
for (a, b) in self.data.iter().zip(rhs.data.iter_mut()) { *b = *a $op *b; } // bool is Copy
|
||||||
rhs
|
rhs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -553,7 +571,7 @@ macro_rules! impl_bitwise_op_all {
|
|||||||
impl $OpTrait<Matrix<bool>> for Matrix<bool> {
|
impl $OpTrait<Matrix<bool>> for Matrix<bool> {
|
||||||
type Output = Matrix<bool>;
|
type Output = Matrix<bool>;
|
||||||
fn $method(mut self, rhs: Matrix<bool>) -> Matrix<bool> {
|
fn $method(mut self, rhs: Matrix<bool>) -> Matrix<bool> {
|
||||||
assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols);
|
check_matrix_dims_for_ops(&self, &rhs);
|
||||||
for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) { *a = *a $op *b; } // bool is Copy
|
for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) { *a = *a $op *b; } // bool is Copy
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user