diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index 509cc8b..db4417f 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -383,6 +383,25 @@ impl Matrix { data: vec![value; rows * cols], // Fill with the specified value } } + + /// Creates a new matrix by broadcasting a 1-row matrix to a target shape. + /// Panics if `self` is not a 1-row matrix or if `self.cols()` does not match `target_cols`. + pub fn broadcast_row_to_target_shape(&self, target_rows: usize, target_cols: usize) -> Matrix { + assert_eq!(self.rows(), 1, "broadcast_row_to_target_shape can only be called on a 1-row matrix."); + assert_eq!(self.cols(), target_cols, "Column count mismatch for broadcasting: source has {} columns, target has {} columns.", self.cols(), target_cols); + + let mut data = Vec::with_capacity(target_rows * target_cols); + let original_row_data = self.row(0); // Get the single row data + + for _ in 0..target_rows { // Repeat 'target_rows' times + for value in &original_row_data { // Iterate over elements of the row + data.push(value.clone()); + } + } + // The data is now in row-major order for the new matrix. + // We need to convert it to column-major for Matrix::from_vec. + Matrix::from_rows_vec(data, target_rows, target_cols) + } } impl Matrix { @@ -1992,4 +2011,47 @@ mod tests { assert!(value.is_nan(), "Expected NaN, got {}", value); } } + + #[test] + fn test_broadcast_row_to_target_shape_basic() { + let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3); + let target_rows = 5; + let target_cols = 3; + + let broadcasted = single_row_matrix.broadcast_row_to_target_shape(target_rows, target_cols); + + assert_eq!(broadcasted.rows(), target_rows); + assert_eq!(broadcasted.cols(), target_cols); + + for r in 0..target_rows { + assert_eq!(broadcasted.row(r), vec![1.0, 2.0, 3.0]); + } + } + + #[test] + fn test_broadcast_row_to_target_shape_single_row() { + let single_row_matrix = Matrix::from_rows_vec(vec![10.0, 20.0], 1, 2); + let target_rows = 1; + let target_cols = 2; + + let broadcasted = single_row_matrix.broadcast_row_to_target_shape(target_rows, target_cols); + + assert_eq!(broadcasted.rows(), target_rows); + assert_eq!(broadcasted.cols(), target_cols); + assert_eq!(broadcasted.row(0), vec![10.0, 20.0]); + } + + #[test] + #[should_panic(expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix.")] + fn test_broadcast_row_to_target_shape_panic_not_1_row() { + let multi_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); + multi_row_matrix.broadcast_row_to_target_shape(3, 2); + } + + #[test] + #[should_panic(expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns.")] + fn test_broadcast_row_to_target_shape_panic_col_mismatch() { + let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3); + single_row_matrix.broadcast_row_to_target_shape(5, 4); + } }