Refactor element-wise arithmetic operations to use a centralized dimension check for improved error handling

This commit is contained in:
Palash Tyagi 2025-05-03 01:11:59 +01:00
parent ca734fbedf
commit 3cb68be062

View File

@ -407,6 +407,23 @@ impl_elementwise_cmp! {
ge_elem => >=,
}
// --- Element-wise Arithmetic Operations (Macros generating all ownership variants) ---
fn check_matrix_dims_for_ops<T>(lhs: &Matrix<T>, rhs: &Matrix<T>) {
assert!(
lhs.rows == rhs.rows,
"Row count mismatch: left has {} rows, right has {} rows",
lhs.rows,
rhs.rows
);
assert!(
lhs.cols == rhs.cols,
"Column count mismatch: left has {} columns, right has {} columns",
lhs.cols,
rhs.cols
);
}
/// Generates element-wise arithmetic implementations for Matrix + Matrix
macro_rules! impl_elementwise_op_matrix_all {
($OpTrait:ident, $method:ident, $op:tt) => {
@ -415,7 +432,7 @@ macro_rules! impl_elementwise_op_matrix_all {
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> {
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
check_matrix_dims_for_ops(self, rhs);
let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
Matrix { rows: self.rows, cols: self.cols, data }
}
@ -425,7 +442,7 @@ macro_rules! impl_elementwise_op_matrix_all {
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(mut self, rhs: &'b Matrix<T>) -> Matrix<T> { // Make self mutable for potential in-place modification
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
check_matrix_dims_for_ops(&self, rhs);
// Modify data in place
for (a, b) in self.data.iter_mut().zip(rhs.data.iter().cloned()) {
*a = a.clone() $op b; // Requires T: Clone for the *a = part
@ -442,7 +459,7 @@ macro_rules! impl_elementwise_op_matrix_all {
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(self, mut rhs: Matrix<T>) -> Matrix<T> { // Make rhs mutable
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
check_matrix_dims_for_ops(self, &rhs);
// Modify rhs data in place
for (a, b) in self.data.iter().cloned().zip(rhs.data.iter_mut()) {
*b = a $op b.clone(); // Requires T: Clone for the *b = part
@ -459,7 +476,7 @@ macro_rules! impl_elementwise_op_matrix_all {
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(mut self, rhs: Matrix<T>) -> Matrix<T> { // Make self mutable
assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch");
check_matrix_dims_for_ops(&self, &rhs);
// Modify self data in place
for (a, b) in self.data.iter_mut().zip(rhs.data.into_iter()) {
*a = a.clone() $op b; // Requires T: Clone
@ -519,14 +536,15 @@ impl_elementwise_op_scalar_all!(Sub, sub, -);
impl_elementwise_op_scalar_all!(Mul, mul, *);
impl_elementwise_op_scalar_all!(Div, div, /);
/// Generates element-wise bitwise operations for boolean matrices.
// --- Element-wise Bitwise Operations (BoolMatrix) ---
macro_rules! impl_bitwise_op_all {
($OpTrait:ident, $method:ident, $op:tt) => {
// &Matrix<bool> OP &Matrix<bool>
impl<'a, 'b> $OpTrait<&'b Matrix<bool>> for &'a Matrix<bool> {
type Output = Matrix<bool>;
fn $method(self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols);
check_matrix_dims_for_ops(self, rhs);
let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
Matrix { rows: self.rows, cols: self.cols, data }
}
@ -535,17 +553,17 @@ macro_rules! impl_bitwise_op_all {
impl<'b> $OpTrait<&'b Matrix<bool>> for Matrix<bool> {
type Output = Matrix<bool>;
fn $method(mut self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols);
check_matrix_dims_for_ops(&self, rhs);
for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) { *a = *a $op *b; } // bool is Copy
self
}
}
// &Matrix<bool> OP Matrix<bool>
impl<'a> $OpTrait<Matrix<bool>> for &'a Matrix<bool> {
impl<'a> $OpTrait<Matrix<bool>> for &'a Matrix<bool> {
type Output = Matrix<bool>;
fn $method(self, mut rhs: Matrix<bool>) -> Matrix<bool> {
assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols);
for (a, b) in self.data.iter().zip(rhs.data.iter_mut()) { *b = *a $op *b; } // bool is Copy
check_matrix_dims_for_ops(self, &rhs);
for (a, b) in self.data.iter().zip(rhs.data.iter_mut()) { *b = *a $op *b; } // bool is Copy
rhs
}
}
@ -553,7 +571,7 @@ macro_rules! impl_bitwise_op_all {
impl $OpTrait<Matrix<bool>> for Matrix<bool> {
type Output = Matrix<bool>;
fn $method(mut self, rhs: Matrix<bool>) -> Matrix<bool> {
assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols);
check_matrix_dims_for_ops(&self, &rhs);
for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) { *a = *a $op *b; } // bool is Copy
self
}