mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
applied formatting
This commit is contained in:
parent
a3bb509202
commit
9b08eaeb35
@ -386,15 +386,31 @@ impl<T: Clone> Matrix<T> {
|
|||||||
|
|
||||||
/// Creates a new matrix by broadcasting a 1-row matrix to a target shape.
|
/// Creates a new matrix by broadcasting a 1-row matrix to a target shape.
|
||||||
/// Panics if `self` is not a 1-row matrix or if `self.cols()` does not match `target_cols`.
|
/// Panics if `self` is not a 1-row matrix or if `self.cols()` does not match `target_cols`.
|
||||||
pub fn broadcast_row_to_target_shape(&self, target_rows: usize, target_cols: usize) -> Matrix<T> {
|
pub fn broadcast_row_to_target_shape(
|
||||||
assert_eq!(self.rows(), 1, "broadcast_row_to_target_shape can only be called on a 1-row matrix.");
|
&self,
|
||||||
assert_eq!(self.cols(), target_cols, "Column count mismatch for broadcasting: source has {} columns, target has {} columns.", self.cols(), target_cols);
|
target_rows: usize,
|
||||||
|
target_cols: usize,
|
||||||
|
) -> Matrix<T> {
|
||||||
|
assert_eq!(
|
||||||
|
self.rows(),
|
||||||
|
1,
|
||||||
|
"broadcast_row_to_target_shape can only be called on a 1-row matrix."
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
self.cols(),
|
||||||
|
target_cols,
|
||||||
|
"Column count mismatch for broadcasting: source has {} columns, target has {} columns.",
|
||||||
|
self.cols(),
|
||||||
|
target_cols
|
||||||
|
);
|
||||||
|
|
||||||
let mut data = Vec::with_capacity(target_rows * target_cols);
|
let mut data = Vec::with_capacity(target_rows * target_cols);
|
||||||
let original_row_data = self.row(0); // Get the single row data
|
let original_row_data = self.row(0); // Get the single row data
|
||||||
|
|
||||||
for _ in 0..target_rows { // Repeat 'target_rows' times
|
for _ in 0..target_rows {
|
||||||
for value in &original_row_data { // Iterate over elements of the row
|
// Repeat 'target_rows' times
|
||||||
|
for value in &original_row_data {
|
||||||
|
// Iterate over elements of the row
|
||||||
data.push(value.clone());
|
data.push(value.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2049,14 +2065,18 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix.")]
|
#[should_panic(
|
||||||
|
expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix."
|
||||||
|
)]
|
||||||
fn test_broadcast_row_to_target_shape_panic_not_1_row() {
|
fn test_broadcast_row_to_target_shape_panic_not_1_row() {
|
||||||
let multi_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
let multi_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||||
multi_row_matrix.broadcast_row_to_target_shape(3, 2);
|
multi_row_matrix.broadcast_row_to_target_shape(3, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns.")]
|
#[should_panic(
|
||||||
|
expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns."
|
||||||
|
)]
|
||||||
fn test_broadcast_row_to_target_shape_panic_col_mismatch() {
|
fn test_broadcast_row_to_target_shape_panic_col_mismatch() {
|
||||||
let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
|
let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
|
||||||
single_row_matrix.broadcast_row_to_target_shape(5, 4);
|
single_row_matrix.broadcast_row_to_target_shape(5, 4);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user