mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
applied formatting
This commit is contained in:
parent
a3bb509202
commit
9b08eaeb35
@ -386,15 +386,31 @@ impl<T: Clone> Matrix<T> {
|
||||
|
||||
/// Creates a new matrix by broadcasting a 1-row matrix to a target shape.
|
||||
/// Panics if `self` is not a 1-row matrix or if `self.cols()` does not match `target_cols`.
|
||||
pub fn broadcast_row_to_target_shape(&self, target_rows: usize, target_cols: usize) -> Matrix<T> {
|
||||
assert_eq!(self.rows(), 1, "broadcast_row_to_target_shape can only be called on a 1-row matrix.");
|
||||
assert_eq!(self.cols(), target_cols, "Column count mismatch for broadcasting: source has {} columns, target has {} columns.", self.cols(), target_cols);
|
||||
pub fn broadcast_row_to_target_shape(
|
||||
&self,
|
||||
target_rows: usize,
|
||||
target_cols: usize,
|
||||
) -> Matrix<T> {
|
||||
assert_eq!(
|
||||
self.rows(),
|
||||
1,
|
||||
"broadcast_row_to_target_shape can only be called on a 1-row matrix."
|
||||
);
|
||||
assert_eq!(
|
||||
self.cols(),
|
||||
target_cols,
|
||||
"Column count mismatch for broadcasting: source has {} columns, target has {} columns.",
|
||||
self.cols(),
|
||||
target_cols
|
||||
);
|
||||
|
||||
let mut data = Vec::with_capacity(target_rows * target_cols);
|
||||
let original_row_data = self.row(0); // Get the single row data
|
||||
|
||||
for _ in 0..target_rows { // Repeat 'target_rows' times
|
||||
for value in &original_row_data { // Iterate over elements of the row
|
||||
for _ in 0..target_rows {
|
||||
// Repeat 'target_rows' times
|
||||
for value in &original_row_data {
|
||||
// Iterate over elements of the row
|
||||
data.push(value.clone());
|
||||
}
|
||||
}
|
||||
@ -2049,14 +2065,18 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix.")]
|
||||
#[should_panic(
|
||||
expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix."
|
||||
)]
|
||||
fn test_broadcast_row_to_target_shape_panic_not_1_row() {
|
||||
let multi_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
multi_row_matrix.broadcast_row_to_target_shape(3, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns.")]
|
||||
#[should_panic(
|
||||
expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns."
|
||||
)]
|
||||
fn test_broadcast_row_to_target_shape_panic_col_mismatch() {
|
||||
let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
|
||||
single_row_matrix.broadcast_row_to_target_shape(5, 4);
|
||||
|
Loading…
x
Reference in New Issue
Block a user