diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index c7b0893..649d831 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -327,6 +327,21 @@ impl Matrix { self.data = new_data; self.rows = new_rows; } + + /// Return a new matrix where row 0 of `self` is repeated `n` times. + pub fn repeat_rows(&self, n: usize) -> Matrix + where + T: Clone, + { + let mut data = Vec::with_capacity(n * self.cols()); + let zeroth_row = self.row(0); + for value in &zeroth_row { + for _ in 0..n { + data.push(value.clone()); // Clone each element + } + } + Matrix::from_vec(data, n, self.cols) + } } impl Matrix { @@ -1158,13 +1173,24 @@ mod tests { } #[test] - fn test_shape(){ + fn test_shape() { let ma = static_test_matrix_2x4(); assert_eq!(ma.shape(), (2, 4)); assert_eq!(ma.rows(), 2); assert_eq!(ma.cols(), 4); } + #[test] + fn test_repeat_rows() { + let ma = static_test_matrix(); + // Returns a new matrix where row 0 of `self` is repeated `n` times. + let repeated = ma.repeat_rows(3); + // assert all rows are equal to the first row + for r in 0..repeated.rows() { + assert_eq!(repeated.row(r), ma.row(0)); + } + } + #[test] #[should_panic(expected = "row index 3 out of bounds for 3 rows")] fn test_row_out_of_bounds() {