Improve assertions and documentation for matrix operations; allow 0-row matrices from empty columns.

This commit is contained in:
Palash Tyagi 2025-05-01 23:56:06 +01:00
parent ba1e2b3d43
commit f91ddc5c41

View File

@ -89,9 +89,12 @@ impl<T: Clone> Matrix<T> {
self.cols self.cols
} }
/// Get element reference (immutable). Panics on out-of-bounds.
pub fn get(&self, r: usize, c: usize) -> &T { pub fn get(&self, r: usize, c: usize) -> &T {
&self[(r, c)] &self[(r, c)]
} }
/// Get element reference (mutable). Panics on out-of-bounds.
pub fn get_mut(&mut self, r: usize, c: usize) -> &mut T { pub fn get_mut(&mut self, r: usize, c: usize) -> &mut T {
&mut self[(r, c)] &mut self[(r, c)]
} }
@ -107,6 +110,7 @@ impl<T: Clone> Matrix<T> {
let start = c * self.rows; let start = c * self.rows;
&self.data[start..start + self.rows] &self.data[start..start + self.rows]
} }
#[inline] #[inline]
pub fn column_mut(&mut self, c: usize) -> &mut [T] { pub fn column_mut(&mut self, c: usize) -> &mut [T] {
assert!( assert!(
@ -130,7 +134,7 @@ impl<T: Clone> Matrix<T> {
}) })
} }
/// Swaps two columns in the matrix. /// Swaps two columns in the matrix. Panics on out-of-bounds.
pub fn swap_columns(&mut self, c1: usize, c2: usize) { pub fn swap_columns(&mut self, c1: usize, c2: usize) {
assert!( assert!(
c1 < self.cols, c1 < self.cols,
@ -195,7 +199,7 @@ impl<T: Clone> Matrix<T> {
/// Adds a row to the matrix at the specified index. /// Adds a row to the matrix at the specified index.
pub fn add_row(&mut self, index: usize, row: Vec<T>) { pub fn add_row(&mut self, index: usize, row: Vec<T>) {
assert!(index <= self.rows,"add_row index {} out of bounds for {} rows",index,self.rows); assert!(index <= self.rows,"add_row index {} out of bounds for {} rows",index,self.rows);
assert_eq!(row.len(), self.cols, "row length mismatch"); assert_eq!(row.len(),self.cols,"row length mismatch: expected {} (cols), got {}",self.cols,row.len());
for (c, value) in row.into_iter().enumerate() { for (c, value) in row.into_iter().enumerate() {
self.data.insert(c * (self.rows + 1) + index, value); self.data.insert(c * (self.rows + 1) + index, value);
@ -261,14 +265,16 @@ impl<'a, T> MatrixRow<'a, T> {
/// Specifies the axis along which to perform a reduction operation. /// Specifies the axis along which to perform a reduction operation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Axis { pub enum Axis {
/// Apply reduction along columns (vertical axis). /// Apply reduction along columns (vertical axis). Result has 1 row.
Col, Col,
/// Apply reduction along rows (horizontal axis). /// Apply reduction along rows (horizontal axis). Result has 1 column.
Row, Row,
} }
// --- Broadcasting ---
/// A trait to turn either a `Matrix<T>` or a scalar T into a `Vec<T>` of /// A trait to turn either a `Matrix<T>` or a scalar T into a `Vec<T>` of
/// length `rows*cols` (broadcasting the scalar). /// length `rows*cols` (broadcasting the scalar). Used for comparisons.
pub trait Broadcastable<T> { pub trait Broadcastable<T> {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>; fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>;
} }