Add broadcasting functionality for 1-row matrices with tests

This commit is contained in:
Palash Tyagi 2025-07-12 00:22:22 +01:00
parent 37c0d312e5
commit 58b0a5f0d9

View File

@ -383,6 +383,25 @@ impl<T: Clone> Matrix<T> {
data: vec![value; rows * cols], // Fill with the specified value data: vec![value; rows * cols], // Fill with the specified value
} }
} }
/// Creates a new matrix by broadcasting a 1-row matrix to a target shape.
/// Panics if `self` is not a 1-row matrix or if `self.cols()` does not match `target_cols`.
pub fn broadcast_row_to_target_shape(&self, target_rows: usize, target_cols: usize) -> Matrix<T> {
assert_eq!(self.rows(), 1, "broadcast_row_to_target_shape can only be called on a 1-row matrix.");
assert_eq!(self.cols(), target_cols, "Column count mismatch for broadcasting: source has {} columns, target has {} columns.", self.cols(), target_cols);
let mut data = Vec::with_capacity(target_rows * target_cols);
let original_row_data = self.row(0); // Get the single row data
for _ in 0..target_rows { // Repeat 'target_rows' times
for value in &original_row_data { // Iterate over elements of the row
data.push(value.clone());
}
}
// The data is now in row-major order for the new matrix.
// We need to convert it to column-major for Matrix::from_vec.
Matrix::from_rows_vec(data, target_rows, target_cols)
}
} }
impl Matrix<f64> { impl Matrix<f64> {
@ -1992,4 +2011,47 @@ mod tests {
assert!(value.is_nan(), "Expected NaN, got {}", value); assert!(value.is_nan(), "Expected NaN, got {}", value);
} }
} }
#[test]
fn test_broadcast_row_to_target_shape_basic() {
let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
let target_rows = 5;
let target_cols = 3;
let broadcasted = single_row_matrix.broadcast_row_to_target_shape(target_rows, target_cols);
assert_eq!(broadcasted.rows(), target_rows);
assert_eq!(broadcasted.cols(), target_cols);
for r in 0..target_rows {
assert_eq!(broadcasted.row(r), vec![1.0, 2.0, 3.0]);
}
}
#[test]
fn test_broadcast_row_to_target_shape_single_row() {
let single_row_matrix = Matrix::from_rows_vec(vec![10.0, 20.0], 1, 2);
let target_rows = 1;
let target_cols = 2;
let broadcasted = single_row_matrix.broadcast_row_to_target_shape(target_rows, target_cols);
assert_eq!(broadcasted.rows(), target_rows);
assert_eq!(broadcasted.cols(), target_cols);
assert_eq!(broadcasted.row(0), vec![10.0, 20.0]);
}
#[test]
#[should_panic(expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix.")]
fn test_broadcast_row_to_target_shape_panic_not_1_row() {
let multi_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
multi_row_matrix.broadcast_row_to_target_shape(3, 2);
}
#[test]
#[should_panic(expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns.")]
fn test_broadcast_row_to_target_shape_panic_col_mismatch() {
let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
single_row_matrix.broadcast_row_to_target_shape(5, 4);
}
} }