diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index 649d831..8eded45 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -198,6 +198,18 @@ impl Matrix { row_data } + #[inline] + pub fn row_mut(&mut self, r: usize) -> &mut [T] { + assert!( + r < self.rows, + "row index {} out of bounds for {} rows", + r, + self.rows + ); + let start = r; + &mut self.data[start..start + self.cols] + } + /// Deletes a row from the matrix. Panics on out-of-bounds. /// This is O(N) where N is the number of elements, as it rebuilds the data vec. pub fn delete_row(&mut self, row: usize) { @@ -1172,6 +1184,25 @@ mod tests { assert_eq!(ma.row(2), &[3, 6, 9]); } + #[test] + fn test_row_mut() { + let mut ma = static_test_matrix(); + let row1_mut = ma.row_mut(1); + row1_mut[0] = 20; + row1_mut[1] = 50; + row1_mut[2] = 80; + + assert_eq!(ma.row(1), &[20, 50, 80]); + assert_eq!(ma.data(), &[1, 2, 3, 20, 50, 80, 7, 8, 9]); + } + + #[test] + #[should_panic(expected = "row index 3 out of bounds for 3 rows")] + fn test_row_mut_out_of_bounds() { + let mut ma = static_test_matrix(); + ma.row_mut(3); + } + #[test] fn test_shape() { let ma = static_test_matrix_2x4();