applied formatting

This commit is contained in:
Palash Tyagi 2025-07-12 00:55:44 +01:00
parent a3bb509202
commit 9b08eaeb35

View File

@ -386,15 +386,31 @@ impl<T: Clone> Matrix<T> {
/// Creates a new matrix by broadcasting a 1-row matrix to a target shape.
/// Panics if `self` is not a 1-row matrix or if `self.cols()` does not match `target_cols`.
pub fn broadcast_row_to_target_shape(&self, target_rows: usize, target_cols: usize) -> Matrix<T> {
assert_eq!(self.rows(), 1, "broadcast_row_to_target_shape can only be called on a 1-row matrix.");
assert_eq!(self.cols(), target_cols, "Column count mismatch for broadcasting: source has {} columns, target has {} columns.", self.cols(), target_cols);
pub fn broadcast_row_to_target_shape(
&self,
target_rows: usize,
target_cols: usize,
) -> Matrix<T> {
assert_eq!(
self.rows(),
1,
"broadcast_row_to_target_shape can only be called on a 1-row matrix."
);
assert_eq!(
self.cols(),
target_cols,
"Column count mismatch for broadcasting: source has {} columns, target has {} columns.",
self.cols(),
target_cols
);
let mut data = Vec::with_capacity(target_rows * target_cols);
let original_row_data = self.row(0); // Get the single row data
for _ in 0..target_rows { // Repeat 'target_rows' times
for value in &original_row_data { // Iterate over elements of the row
for _ in 0..target_rows {
// Repeat 'target_rows' times
for value in &original_row_data {
// Iterate over elements of the row
data.push(value.clone());
}
}
@ -2049,14 +2065,18 @@ mod tests {
}
#[test]
#[should_panic(expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix.")]
#[should_panic(
expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix."
)]
fn test_broadcast_row_to_target_shape_panic_not_1_row() {
let multi_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
multi_row_matrix.broadcast_row_to_target_shape(3, 2);
}
#[test]
#[should_panic(expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns.")]
#[should_panic(
expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns."
)]
fn test_broadcast_row_to_target_shape_panic_col_mismatch() {
let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
single_row_matrix.broadcast_row_to_target_shape(5, 4);