mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:19:59 +00:00
Add broadcasting functionality for 1-row matrices with tests
This commit is contained in:
parent
37c0d312e5
commit
58b0a5f0d9
@ -383,6 +383,25 @@ impl<T: Clone> Matrix<T> {
|
|||||||
data: vec![value; rows * cols], // Fill with the specified value
|
data: vec![value; rows * cols], // Fill with the specified value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates a new matrix by broadcasting a 1-row matrix to a target shape.
|
||||||
|
/// Panics if `self` is not a 1-row matrix or if `self.cols()` does not match `target_cols`.
|
||||||
|
pub fn broadcast_row_to_target_shape(&self, target_rows: usize, target_cols: usize) -> Matrix<T> {
|
||||||
|
assert_eq!(self.rows(), 1, "broadcast_row_to_target_shape can only be called on a 1-row matrix.");
|
||||||
|
assert_eq!(self.cols(), target_cols, "Column count mismatch for broadcasting: source has {} columns, target has {} columns.", self.cols(), target_cols);
|
||||||
|
|
||||||
|
let mut data = Vec::with_capacity(target_rows * target_cols);
|
||||||
|
let original_row_data = self.row(0); // Get the single row data
|
||||||
|
|
||||||
|
for _ in 0..target_rows { // Repeat 'target_rows' times
|
||||||
|
for value in &original_row_data { // Iterate over elements of the row
|
||||||
|
data.push(value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// The data is now in row-major order for the new matrix.
|
||||||
|
// We need to convert it to column-major for Matrix::from_vec.
|
||||||
|
Matrix::from_rows_vec(data, target_rows, target_cols)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Matrix<f64> {
|
impl Matrix<f64> {
|
||||||
@ -1992,4 +2011,47 @@ mod tests {
|
|||||||
assert!(value.is_nan(), "Expected NaN, got {}", value);
|
assert!(value.is_nan(), "Expected NaN, got {}", value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_broadcast_row_to_target_shape_basic() {
|
||||||
|
let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
|
||||||
|
let target_rows = 5;
|
||||||
|
let target_cols = 3;
|
||||||
|
|
||||||
|
let broadcasted = single_row_matrix.broadcast_row_to_target_shape(target_rows, target_cols);
|
||||||
|
|
||||||
|
assert_eq!(broadcasted.rows(), target_rows);
|
||||||
|
assert_eq!(broadcasted.cols(), target_cols);
|
||||||
|
|
||||||
|
for r in 0..target_rows {
|
||||||
|
assert_eq!(broadcasted.row(r), vec![1.0, 2.0, 3.0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_broadcast_row_to_target_shape_single_row() {
|
||||||
|
let single_row_matrix = Matrix::from_rows_vec(vec![10.0, 20.0], 1, 2);
|
||||||
|
let target_rows = 1;
|
||||||
|
let target_cols = 2;
|
||||||
|
|
||||||
|
let broadcasted = single_row_matrix.broadcast_row_to_target_shape(target_rows, target_cols);
|
||||||
|
|
||||||
|
assert_eq!(broadcasted.rows(), target_rows);
|
||||||
|
assert_eq!(broadcasted.cols(), target_cols);
|
||||||
|
assert_eq!(broadcasted.row(0), vec![10.0, 20.0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix.")]
|
||||||
|
fn test_broadcast_row_to_target_shape_panic_not_1_row() {
|
||||||
|
let multi_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||||
|
multi_row_matrix.broadcast_row_to_target_shape(3, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns.")]
|
||||||
|
fn test_broadcast_row_to_target_shape_panic_col_mismatch() {
|
||||||
|
let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
|
||||||
|
single_row_matrix.broadcast_row_to_target_shape(5, 4);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user