Compare commits

..

No commits in common. "f91ddc5c4109825e8d5d1621f5c80e2a8b735051" and "3c8b69d3ab873fbd5b1f6f15a858093e9656d3a7" have entirely different histories.

View File

@ -1,6 +1,4 @@
//! A simple column-major Matrix implementation with element-wise operations. use std::ops::{Index, IndexMut, Not};
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Index, IndexMut, Mul, Not, Sub};
/// A columnmajor 2D matrix of `T` /// A columnmajor 2D matrix of `T`
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
@ -15,15 +13,9 @@ impl<T: Clone> Matrix<T> {
pub fn from_cols(cols_data: Vec<Vec<T>>) -> Self { pub fn from_cols(cols_data: Vec<Vec<T>>) -> Self {
let cols = cols_data.len(); let cols = cols_data.len();
assert!(cols > 0, "need at least one column"); assert!(cols > 0, "need at least one column");
// Handle empty cols_data let rows = cols_data[0].len();
let rows = cols_data.get(0).map_or(0, |c| c.len()); assert!(rows > 0, "need at least one row");
// Allow 0-row matrices if columns are empty, but not 0-col matrices if rows > 0 for (i, col) in cols_data.iter().enumerate().skip(1) {
assert!(
rows > 0 || cols == 0,
"need at least one row if columns exist"
);
for (i, col) in cols_data.iter().enumerate() {
assert!( assert!(
col.len() == rows, col.len() == rows,
"col {} has len {}, expected {}", "col {} has len {}, expected {}",
@ -32,34 +24,17 @@ impl<T: Clone> Matrix<T> {
rows rows
); );
} }
// Flatten column data directly let mut data = Vec::with_capacity(rows * cols);
let data = cols_data.into_iter().flatten().collect(); for col in cols_data {
data.extend(col);
}
Matrix { rows, cols, data } Matrix { rows, cols, data }
} }
/// Build from a flat Vec, assuming column-major order.
pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> Self { pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> Self {
assert!( assert!(rows > 0, "need at least one row");
rows > 0 || cols == 0, assert!(cols > 0, "need at least one column");
"need at least one row if columns exist" assert_eq!(data.len(), rows * cols, "data length mismatch");
);
assert!(
cols > 0 || rows == 0,
"need at least one column if rows exist"
);
if rows * cols != 0 {
// Only assert length if matrix is non-empty
assert_eq!(
data.len(),
rows * cols,
"data length mismatch: expected {}, got {}",
rows * cols,
data.len()
);
} else {
assert!(data.is_empty(), "data must be empty for 0-sized matrix");
}
Matrix { rows, cols, data } Matrix { rows, cols, data }
} }
@ -71,13 +46,7 @@ impl<T: Clone> Matrix<T> {
&mut self.data &mut self.data
} }
/// Consumes the Matrix and returns its underlying data Vec. pub fn as_vec(&self) -> Vec<T> {
pub fn into_vec(self) -> Vec<T> {
self.data
}
/// Creates a new Vec<T> containing the matrix data (cloned).
pub fn to_vec(&self) -> Vec<T> {
self.data.clone() self.data.clone()
} }
@ -89,36 +58,20 @@ impl<T: Clone> Matrix<T> {
self.cols self.cols
} }
/// Get element reference (immutable). Panics on out-of-bounds.
pub fn get(&self, r: usize, c: usize) -> &T { pub fn get(&self, r: usize, c: usize) -> &T {
&self[(r, c)] &self[(r, c)]
} }
/// Get element reference (mutable). Panics on out-of-bounds.
pub fn get_mut(&mut self, r: usize, c: usize) -> &mut T { pub fn get_mut(&mut self, r: usize, c: usize) -> &mut T {
&mut self[(r, c)] &mut self[(r, c)]
} }
#[inline] #[inline]
pub fn column(&self, c: usize) -> &[T] { pub fn column(&self, c: usize) -> &[T] {
assert!(
c < self.cols,
"column index {} out of bounds for {} columns",
c,
self.cols
);
let start = c * self.rows; let start = c * self.rows;
&self.data[start..start + self.rows] &self.data[start..start + self.rows]
} }
#[inline] #[inline]
pub fn column_mut(&mut self, c: usize) -> &mut [T] { pub fn column_mut(&mut self, c: usize) -> &mut [T] {
assert!(
c < self.cols,
"column index {} out of bounds for {} columns",
c,
self.cols
);
let start = c * self.rows; let start = c * self.rows;
&mut self.data[start..start + self.rows] &mut self.data[start..start + self.rows]
} }
@ -134,19 +87,11 @@ impl<T: Clone> Matrix<T> {
}) })
} }
/// Swaps two columns in the matrix. Panics on out-of-bounds. /// Swaps two columns in the matrix.
pub fn swap_columns(&mut self, c1: usize, c2: usize) { pub fn swap_columns(&mut self, c1: usize, c2: usize) {
assert!( assert!(
c1 < self.cols, c1 < self.cols && c2 < self.cols,
"column index c1={} out of bounds for {} columns", "column index out of bounds"
c1,
self.cols
);
assert!(
c2 < self.cols,
"column index c2={} out of bounds for {} columns",
c2,
self.cols
); );
if c1 == c2 { if c1 == c2 {
// Indices are equal; no operation required // Indices are equal; no operation required
@ -187,7 +132,7 @@ impl<T: Clone> Matrix<T> {
impl<T: Clone> Matrix<T> { impl<T: Clone> Matrix<T> {
/// Adds a column to the matrix at the specified index. /// Adds a column to the matrix at the specified index.
pub fn add_column(&mut self, index: usize, column: Vec<T>) { pub fn add_column(&mut self, index: usize, column: Vec<T>) {
assert!(index <= self.cols,"add_column index {} out of bounds for {} columns",index,self.cols); assert!(index <= self.cols, "column index out of bounds");
assert_eq!(column.len(), self.rows, "column length mismatch"); assert_eq!(column.len(), self.rows, "column length mismatch");
for (r, value) in column.into_iter().enumerate() { for (r, value) in column.into_iter().enumerate() {
@ -198,8 +143,8 @@ impl<T: Clone> Matrix<T> {
/// Adds a row to the matrix at the specified index. /// Adds a row to the matrix at the specified index.
pub fn add_row(&mut self, index: usize, row: Vec<T>) { pub fn add_row(&mut self, index: usize, row: Vec<T>) {
assert!(index <= self.rows,"add_row index {} out of bounds for {} rows",index,self.rows); assert!(index <= self.rows, "row index out of bounds");
assert_eq!(row.len(),self.cols,"row length mismatch: expected {} (cols), got {}",self.cols,row.len()); assert_eq!(row.len(), self.cols, "row length mismatch");
for (c, value) in row.into_iter().enumerate() { for (c, value) in row.into_iter().enumerate() {
self.data.insert(c * (self.rows + 1) + index, value); self.data.insert(c * (self.rows + 1) + index, value);
@ -214,14 +159,7 @@ impl<T> Index<(usize, usize)> for Matrix<T> {
#[inline] #[inline]
fn index(&self, (r, c): (usize, usize)) -> &T { fn index(&self, (r, c): (usize, usize)) -> &T {
// Validate that the requested indices are within bounds // Validate that the requested indices are within bounds
assert!( assert!(r < self.rows && c < self.cols, "index out of bounds");
r < self.rows && c < self.cols,
"index out of bounds: ({}, {}) vs {}x{}",
r,
c,
self.rows,
self.cols
);
// Compute column-major offset and return reference // Compute column-major offset and return reference
&self.data[c * self.rows + r] &self.data[c * self.rows + r]
} }
@ -231,14 +169,7 @@ impl<T> IndexMut<(usize, usize)> for Matrix<T> {
#[inline] #[inline]
fn index_mut(&mut self, (r, c): (usize, usize)) -> &mut T { fn index_mut(&mut self, (r, c): (usize, usize)) -> &mut T {
// Validate that the requested indices are within bounds // Validate that the requested indices are within bounds
assert!( assert!(r < self.rows && c < self.cols, "index out of bounds");
r < self.rows && c < self.cols,
"index out of bounds: ({}, {}) vs {}x{}",
r,
c,
self.rows,
self.cols
);
// Compute column-major offset and return mutable reference // Compute column-major offset and return mutable reference
&mut self.data[c * self.rows + r] &mut self.data[c * self.rows + r]
} }
@ -265,16 +196,14 @@ impl<'a, T> MatrixRow<'a, T> {
/// Specifies the axis along which to perform a reduction operation. /// Specifies the axis along which to perform a reduction operation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Axis { pub enum Axis {
/// Apply reduction along columns (vertical axis). Result has 1 row. /// Apply reduction along columns (vertical axis).
Col, Col,
/// Apply reduction along rows (horizontal axis). Result has 1 column. /// Apply reduction along rows (horizontal axis).
Row, Row,
} }
// --- Broadcasting ---
/// A trait to turn either a `Matrix<T>` or a scalar T into a `Vec<T>` of /// A trait to turn either a `Matrix<T>` or a scalar T into a `Vec<T>` of
/// length `rows*cols` (broadcasting the scalar). Used for comparisons. /// length `rows*cols` (broadcasting the scalar).
pub trait Broadcastable<T> { pub trait Broadcastable<T> {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>; fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>;
} }