diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index 7e7d46c..6709f07 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -386,15 +386,31 @@ impl Matrix { /// Creates a new matrix by broadcasting a 1-row matrix to a target shape. /// Panics if `self` is not a 1-row matrix or if `self.cols()` does not match `target_cols`. - pub fn broadcast_row_to_target_shape(&self, target_rows: usize, target_cols: usize) -> Matrix { - assert_eq!(self.rows(), 1, "broadcast_row_to_target_shape can only be called on a 1-row matrix."); - assert_eq!(self.cols(), target_cols, "Column count mismatch for broadcasting: source has {} columns, target has {} columns.", self.cols(), target_cols); + pub fn broadcast_row_to_target_shape( + &self, + target_rows: usize, + target_cols: usize, + ) -> Matrix { + assert_eq!( + self.rows(), + 1, + "broadcast_row_to_target_shape can only be called on a 1-row matrix." + ); + assert_eq!( + self.cols(), + target_cols, + "Column count mismatch for broadcasting: source has {} columns, target has {} columns.", + self.cols(), + target_cols + ); let mut data = Vec::with_capacity(target_rows * target_cols); let original_row_data = self.row(0); // Get the single row data - for _ in 0..target_rows { // Repeat 'target_rows' times - for value in &original_row_data { // Iterate over elements of the row + for _ in 0..target_rows { + // Repeat 'target_rows' times + for value in &original_row_data { + // Iterate over elements of the row data.push(value.clone()); } } @@ -2049,14 +2065,18 @@ mod tests { } #[test] - #[should_panic(expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix.")] + #[should_panic( + expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix." + )] fn test_broadcast_row_to_target_shape_panic_not_1_row() { let multi_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); multi_row_matrix.broadcast_row_to_target_shape(3, 2); } #[test] - #[should_panic(expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns.")] + #[should_panic( + expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns." + )] fn test_broadcast_row_to_target_shape_panic_col_mismatch() { let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3); single_row_matrix.broadcast_row_to_target_shape(5, 4);