Improve assertions and documentation for matrix operations; allow 0-row matrices from empty columns.

This commit is contained in:
Palash Tyagi 2025-05-01 23:56:06 +01:00
parent ba1e2b3d43
commit f91ddc5c41

View File

@ -89,9 +89,12 @@ impl<T: Clone> Matrix<T> {
self.cols
}
/// Get element reference (immutable). Panics on out-of-bounds.
pub fn get(&self, r: usize, c: usize) -> &T {
&self[(r, c)]
}
/// Get element reference (mutable). Panics on out-of-bounds.
pub fn get_mut(&mut self, r: usize, c: usize) -> &mut T {
&mut self[(r, c)]
}
@ -107,6 +110,7 @@ impl<T: Clone> Matrix<T> {
let start = c * self.rows;
&self.data[start..start + self.rows]
}
#[inline]
pub fn column_mut(&mut self, c: usize) -> &mut [T] {
assert!(
@ -130,7 +134,7 @@ impl<T: Clone> Matrix<T> {
})
}
/// Swaps two columns in the matrix.
/// Swaps two columns in the matrix. Panics on out-of-bounds.
pub fn swap_columns(&mut self, c1: usize, c2: usize) {
assert!(
c1 < self.cols,
@ -183,7 +187,7 @@ impl<T: Clone> Matrix<T> {
impl<T: Clone> Matrix<T> {
/// Adds a column to the matrix at the specified index.
pub fn add_column(&mut self, index: usize, column: Vec<T>) {
assert!(index <= self.cols, "add_column index {} out of bounds for {} columns", index, self.cols);
assert!(index <= self.cols,"add_column index {} out of bounds for {} columns",index,self.cols);
assert_eq!(column.len(), self.rows, "column length mismatch");
for (r, value) in column.into_iter().enumerate() {
@ -194,8 +198,8 @@ impl<T: Clone> Matrix<T> {
/// Adds a row to the matrix at the specified index.
pub fn add_row(&mut self, index: usize, row: Vec<T>) {
assert!(index <= self.rows, "add_row index {} out of bounds for {} rows", index, self.rows);
assert_eq!(row.len(), self.cols, "row length mismatch");
assert!(index <= self.rows,"add_row index {} out of bounds for {} rows",index,self.rows);
assert_eq!(row.len(),self.cols,"row length mismatch: expected {} (cols), got {}",self.cols,row.len());
for (c, value) in row.into_iter().enumerate() {
self.data.insert(c * (self.rows + 1) + index, value);
@ -261,14 +265,16 @@ impl<'a, T> MatrixRow<'a, T> {
/// Specifies the axis along which to perform a reduction operation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Axis {
/// Apply reduction along columns (vertical axis).
/// Apply reduction along columns (vertical axis). Result has 1 row.
Col,
/// Apply reduction along rows (horizontal axis).
/// Apply reduction along rows (horizontal axis). Result has 1 column.
Row,
}
// --- Broadcasting ---
/// A trait to turn either a `Matrix<T>` or a scalar T into a `Vec<T>` of
/// length `rows*cols` (broadcasting the scalar).
/// length `rows*cols` (broadcasting the scalar). Used for comparisons.
pub trait Broadcastable<T> {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>;
}