From 3cb68be0621bc2b50d4b9dc8c9d64a72b50c9b85 Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Sat, 3 May 2025 01:11:59 +0100 Subject: [PATCH] Refactor element-wise arithmetic operations to use a centralized dimension check for improved error handling --- src/matrix/mat.rs | 40 +++++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index f9db3d9..d1be2d5 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -407,6 +407,23 @@ impl_elementwise_cmp! { ge_elem => >=, } +// --- Element-wise Arithmetic Operations (Macros generating all ownership variants) --- + +fn check_matrix_dims_for_ops(lhs: &Matrix, rhs: &Matrix) { + assert!( + lhs.rows == rhs.rows, + "Row count mismatch: left has {} rows, right has {} rows", + lhs.rows, + rhs.rows + ); + assert!( + lhs.cols == rhs.cols, + "Column count mismatch: left has {} columns, right has {} columns", + lhs.cols, + rhs.cols + ); +} + /// Generates element-wise arithmetic implementations for Matrix + Matrix macro_rules! impl_elementwise_op_matrix_all { ($OpTrait:ident, $method:ident, $op:tt) => { @@ -415,7 +432,7 @@ macro_rules! impl_elementwise_op_matrix_all { where T: Clone + $OpTrait { type Output = Matrix; fn $method(self, rhs: &'b Matrix) -> Matrix { - assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch"); + check_matrix_dims_for_ops(self, rhs); let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect(); Matrix { rows: self.rows, cols: self.cols, data } } @@ -425,7 +442,7 @@ macro_rules! impl_elementwise_op_matrix_all { where T: Clone + $OpTrait { type Output = Matrix; fn $method(mut self, rhs: &'b Matrix) -> Matrix { // Make self mutable for potential in-place modification - assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch"); + check_matrix_dims_for_ops(&self, rhs); // Modify data in place for (a, b) in self.data.iter_mut().zip(rhs.data.iter().cloned()) { *a = a.clone() $op b; // Requires T: Clone for the *a = part @@ -442,7 +459,7 @@ macro_rules! impl_elementwise_op_matrix_all { where T: Clone + $OpTrait { type Output = Matrix; fn $method(self, mut rhs: Matrix) -> Matrix { // Make rhs mutable - assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch"); + check_matrix_dims_for_ops(self, &rhs); // Modify rhs data in place for (a, b) in self.data.iter().cloned().zip(rhs.data.iter_mut()) { *b = a $op b.clone(); // Requires T: Clone for the *b = part @@ -459,7 +476,7 @@ macro_rules! impl_elementwise_op_matrix_all { where T: Clone + $OpTrait { type Output = Matrix; fn $method(mut self, rhs: Matrix) -> Matrix { // Make self mutable - assert_eq!(self.rows, rhs.rows, "row count mismatch"); assert_eq!(self.cols, rhs.cols, "col count mismatch"); + check_matrix_dims_for_ops(&self, &rhs); // Modify self data in place for (a, b) in self.data.iter_mut().zip(rhs.data.into_iter()) { *a = a.clone() $op b; // Requires T: Clone @@ -519,14 +536,15 @@ impl_elementwise_op_scalar_all!(Sub, sub, -); impl_elementwise_op_scalar_all!(Mul, mul, *); impl_elementwise_op_scalar_all!(Div, div, /); -/// Generates element-wise bitwise operations for boolean matrices. +// --- Element-wise Bitwise Operations (BoolMatrix) --- + macro_rules! impl_bitwise_op_all { ($OpTrait:ident, $method:ident, $op:tt) => { // &Matrix OP &Matrix impl<'a, 'b> $OpTrait<&'b Matrix> for &'a Matrix { type Output = Matrix; fn $method(self, rhs: &'b Matrix) -> Matrix { - assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols); + check_matrix_dims_for_ops(self, rhs); let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect(); Matrix { rows: self.rows, cols: self.cols, data } } @@ -535,17 +553,17 @@ macro_rules! impl_bitwise_op_all { impl<'b> $OpTrait<&'b Matrix> for Matrix { type Output = Matrix; fn $method(mut self, rhs: &'b Matrix) -> Matrix { - assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols); + check_matrix_dims_for_ops(&self, rhs); for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) { *a = *a $op *b; } // bool is Copy self } } // &Matrix OP Matrix - impl<'a> $OpTrait> for &'a Matrix { + impl<'a> $OpTrait> for &'a Matrix { type Output = Matrix; fn $method(self, mut rhs: Matrix) -> Matrix { - assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols); - for (a, b) in self.data.iter().zip(rhs.data.iter_mut()) { *b = *a $op *b; } // bool is Copy + check_matrix_dims_for_ops(self, &rhs); + for (a, b) in self.data.iter().zip(rhs.data.iter_mut()) { *b = *a $op *b; } // bool is Copy rhs } } @@ -553,7 +571,7 @@ macro_rules! impl_bitwise_op_all { impl $OpTrait> for Matrix { type Output = Matrix; fn $method(mut self, rhs: Matrix) -> Matrix { - assert_eq!(self.rows, rhs.rows); assert_eq!(self.cols, rhs.cols); + check_matrix_dims_for_ops(&self, &rhs); for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) { *a = *a $op *b; } // bool is Copy self }