Add repeat_rows method to Matrix and corresponding unit test

This commit is contained in:
Palash Tyagi 2025-07-06 19:16:46 +01:00
parent 85154a3be0
commit 2ca496cfd1

View File

@ -327,6 +327,21 @@ impl<T: Clone> Matrix<T> {
self.data = new_data; self.data = new_data;
self.rows = new_rows; self.rows = new_rows;
} }
/// Return a new matrix where row 0 of `self` is repeated `n` times.
pub fn repeat_rows(&self, n: usize) -> Matrix<T>
where
T: Clone,
{
let mut data = Vec::with_capacity(n * self.cols());
let zeroth_row = self.row(0);
for value in &zeroth_row {
for _ in 0..n {
data.push(value.clone()); // Clone each element
}
}
Matrix::from_vec(data, n, self.cols)
}
} }
impl Matrix<f64> { impl Matrix<f64> {
@ -1165,6 +1180,17 @@ mod tests {
assert_eq!(ma.cols(), 4); assert_eq!(ma.cols(), 4);
} }
#[test]
fn test_repeat_rows() {
let ma = static_test_matrix();
// Returns a new matrix where row 0 of `self` is repeated `n` times.
let repeated = ma.repeat_rows(3);
// assert all rows are equal to the first row
for r in 0..repeated.rows() {
assert_eq!(repeated.row(r), ma.row(0));
}
}
#[test] #[test]
#[should_panic(expected = "row index 3 out of bounds for 3 rows")] #[should_panic(expected = "row index 3 out of bounds for 3 rows")]
fn test_row_out_of_bounds() { fn test_row_out_of_bounds() {