diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index a951c05..82acb38 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -16,7 +16,7 @@ impl Matrix { let cols = cols_data.len(); assert!(cols > 0, "need at least one column"); // Handle empty cols_data - let rows = cols_data.get(0).map_or(0, |c| c.len()); + let rows = cols_data.get(0).map_or(0, |c| c.len()); // Allow 0-row matrices if columns are empty, but not 0-col matrices if rows > 0 assert!( rows > 0 || cols == 0, @@ -89,9 +89,12 @@ impl Matrix { self.cols } + /// Get element reference (immutable). Panics on out-of-bounds. pub fn get(&self, r: usize, c: usize) -> &T { &self[(r, c)] } + + /// Get element reference (mutable). Panics on out-of-bounds. pub fn get_mut(&mut self, r: usize, c: usize) -> &mut T { &mut self[(r, c)] } @@ -107,6 +110,7 @@ impl Matrix { let start = c * self.rows; &self.data[start..start + self.rows] } + #[inline] pub fn column_mut(&mut self, c: usize) -> &mut [T] { assert!( @@ -130,7 +134,7 @@ impl Matrix { }) } - /// Swaps two columns in the matrix. + /// Swaps two columns in the matrix. Panics on out-of-bounds. pub fn swap_columns(&mut self, c1: usize, c2: usize) { assert!( c1 < self.cols, @@ -183,7 +187,7 @@ impl Matrix { impl Matrix { /// Adds a column to the matrix at the specified index. pub fn add_column(&mut self, index: usize, column: Vec) { - assert!(index <= self.cols, "add_column index {} out of bounds for {} columns", index, self.cols); + assert!(index <= self.cols,"add_column index {} out of bounds for {} columns",index,self.cols); assert_eq!(column.len(), self.rows, "column length mismatch"); for (r, value) in column.into_iter().enumerate() { @@ -194,8 +198,8 @@ impl Matrix { /// Adds a row to the matrix at the specified index. pub fn add_row(&mut self, index: usize, row: Vec) { - assert!(index <= self.rows, "add_row index {} out of bounds for {} rows", index, self.rows); - assert_eq!(row.len(), self.cols, "row length mismatch"); + assert!(index <= self.rows,"add_row index {} out of bounds for {} rows",index,self.rows); + assert_eq!(row.len(),self.cols,"row length mismatch: expected {} (cols), got {}",self.cols,row.len()); for (c, value) in row.into_iter().enumerate() { self.data.insert(c * (self.rows + 1) + index, value); @@ -261,14 +265,16 @@ impl<'a, T> MatrixRow<'a, T> { /// Specifies the axis along which to perform a reduction operation. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Axis { - /// Apply reduction along columns (vertical axis). + /// Apply reduction along columns (vertical axis). Result has 1 row. Col, - /// Apply reduction along rows (horizontal axis). + /// Apply reduction along rows (horizontal axis). Result has 1 column. Row, } +// --- Broadcasting --- + /// A trait to turn either a `Matrix` or a scalar T into a `Vec` of -/// length `rows*cols` (broadcasting the scalar). +/// length `rows*cols` (broadcasting the scalar). Used for comparisons. pub trait Broadcastable { fn to_vec(&self, rows: usize, cols: usize) -> Vec; }