mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 09:30:01 +00:00
add Matrix<T> struct with core functionality for 2D matrix operations
This commit is contained in:
parent
fd274ebb6d
commit
a161cf0c76
@ -1,4 +1,4 @@
|
||||
use std::ops::{Index, IndexMut};
|
||||
use std::ops::{Index, IndexMut, Not};
|
||||
|
||||
/// A column‑major 2D matrix of `T`
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@ -41,6 +41,10 @@ impl<T> Matrix<T> {
|
||||
pub fn rows(&self) -> usize {
|
||||
self.rows
|
||||
}
|
||||
pub fn data(&self) -> &[T] {
|
||||
&self.data
|
||||
}
|
||||
|
||||
pub fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
@ -57,13 +61,18 @@ impl<T> Matrix<T> {
|
||||
let start = c * self.rows;
|
||||
&self.data[start..start + self.rows]
|
||||
}
|
||||
#[inline]
|
||||
pub fn column_mut(&mut self, c: usize) -> &mut [T] {
|
||||
let start = c * self.rows;
|
||||
&mut self.data[start..start + self.rows]
|
||||
}
|
||||
|
||||
pub fn iter_columns(&self) -> impl Iterator<Item = &[T]> {
|
||||
(0..self.cols).map(move |c| self.column(c))
|
||||
}
|
||||
|
||||
pub fn iter_rows(&self) -> impl Iterator<Item = Row<'_, T>> {
|
||||
(0..self.rows).map(move |r| Row {
|
||||
pub fn iter_rows(&self) -> impl Iterator<Item = MatrixRow<'_, T>> {
|
||||
(0..self.rows).map(move |r| MatrixRow {
|
||||
matrix: self,
|
||||
row: r,
|
||||
})
|
||||
@ -83,6 +92,48 @@ impl<T> Matrix<T> {
|
||||
self.data.swap(c1 * self.rows + r, c2 * self.rows + r);
|
||||
}
|
||||
}
|
||||
|
||||
/// Deletes a column from the matrix.
|
||||
pub fn delete_column(&mut self, col: usize) {
|
||||
assert!(col < self.cols, "column index out of bounds");
|
||||
for r in (0..self.rows).rev() {
|
||||
self.data.remove(col * self.rows + r);
|
||||
}
|
||||
self.cols -= 1;
|
||||
}
|
||||
|
||||
/// Deletes a row from the matrix.
|
||||
pub fn delete_row(&mut self, row: usize) {
|
||||
assert!(row < self.rows, "row index out of bounds");
|
||||
for c in (0..self.cols).rev() {
|
||||
self.data.remove(c * self.rows + row);
|
||||
}
|
||||
self.rows -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone> Matrix<T> {
|
||||
/// Adds a column to the matrix at the specified index.
|
||||
pub fn add_column(&mut self, index: usize, column: Vec<T>) {
|
||||
assert!(index <= self.cols, "column index out of bounds");
|
||||
assert_eq!(column.len(), self.rows, "column length mismatch");
|
||||
|
||||
for (r, value) in column.into_iter().enumerate() {
|
||||
self.data.insert(index * self.rows + r, value);
|
||||
}
|
||||
self.cols += 1;
|
||||
}
|
||||
|
||||
/// Adds a row to the matrix at the specified index.
|
||||
pub fn add_row(&mut self, index: usize, row: Vec<T>) {
|
||||
assert!(index <= self.rows, "row index out of bounds");
|
||||
assert_eq!(row.len(), self.cols, "row length mismatch");
|
||||
|
||||
for (c, value) in row.into_iter().enumerate() {
|
||||
self.data.insert(c * (self.rows + 1) + index, value);
|
||||
}
|
||||
self.rows += 1;
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Index<(usize, usize)> for Matrix<T> {
|
||||
@ -102,11 +153,11 @@ impl<T> IndexMut<(usize, usize)> for Matrix<T> {
|
||||
}
|
||||
|
||||
/// A view of one row
|
||||
pub struct Row<'a, T> {
|
||||
pub struct MatrixRow<'a, T> {
|
||||
matrix: &'a Matrix<T>,
|
||||
row: usize,
|
||||
}
|
||||
impl<'a, T> Row<'a, T> {
|
||||
impl<'a, T> MatrixRow<'a, T> {
|
||||
pub fn get(&self, c: usize) -> &T {
|
||||
&self.matrix[(self.row, c)]
|
||||
}
|
||||
@ -146,7 +197,50 @@ impl_elementwise_op!(Sub, sub, -);
|
||||
impl_elementwise_op!(Mul, mul, *);
|
||||
impl_elementwise_op!(Div, div, /);
|
||||
|
||||
// === New code begins here =====================================================
|
||||
pub type FloatMatrix = Matrix<f64>;
|
||||
pub type BoolMatrix = Matrix<bool>;
|
||||
pub type IntMatrix = Matrix<i32>;
|
||||
pub type StringMatrix = Matrix<String>;
|
||||
|
||||
// implement bit ops - and, or, xor, not -- using Macros
|
||||
|
||||
macro_rules! impl_bitwise_op {
|
||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||
impl<'a, 'b> std::ops::$OpTrait<&'b Matrix<bool>> for &'a Matrix<bool> {
|
||||
type Output = Matrix<bool>;
|
||||
|
||||
fn $method(self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
|
||||
assert_eq!(self.rows, rhs.rows, "row count mismatch");
|
||||
assert_eq!(self.cols, rhs.cols, "col count mismatch");
|
||||
let data = self
|
||||
.data
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(rhs.data.iter().cloned())
|
||||
.map(|(a, b)| a $op b)
|
||||
.collect();
|
||||
Matrix { rows: self.rows, cols: self.cols, data }
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
impl_bitwise_op!(BitAnd, bitand, &);
|
||||
impl_bitwise_op!(BitOr, bitor, |);
|
||||
impl_bitwise_op!(BitXor, bitxor, ^);
|
||||
|
||||
impl Not for Matrix<bool> {
|
||||
type Output = Matrix<bool>;
|
||||
|
||||
fn not(self) -> Matrix<bool> {
|
||||
let data = self.data.iter().map(|&v| !v).collect();
|
||||
Matrix {
|
||||
rows: self.rows,
|
||||
cols: self.cols,
|
||||
data,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Axis along which to apply a reduction.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Axis {
|
||||
@ -155,118 +249,3 @@ pub enum Axis {
|
||||
/// Operate row‑wise (horizontal).
|
||||
Row,
|
||||
}
|
||||
|
||||
pub type FloatMatrix = Matrix<f64>;
|
||||
pub type FloatVector = Vec<f64>;
|
||||
pub type BoolMatrix = Matrix<bool>;
|
||||
pub type IntMatrix = Matrix<i32>;
|
||||
|
||||
impl Matrix<f64> {
|
||||
/// Apply a function along *columns* and collect its result in a `Vec`.
|
||||
/// This is very fast because each column is contiguous in memory.
|
||||
#[inline]
|
||||
fn apply_colwise<U, F>(&self, mut f: F) -> Vec<U>
|
||||
where
|
||||
F: FnMut(&[f64]) -> U,
|
||||
{
|
||||
let mut out = Vec::with_capacity(self.cols);
|
||||
for c in 0..self.cols {
|
||||
out.push(f(self.column(c)));
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Apply a function along *rows* and collect its result in a `Vec`.
|
||||
/// Slower than the column version because data are not contiguous, but a single
|
||||
/// reusable buffer is used to minimize allocations.
|
||||
#[inline]
|
||||
fn apply_rowwise<U, F>(&self, mut f: F) -> Vec<U>
|
||||
where
|
||||
F: FnMut(&[f64]) -> U,
|
||||
{
|
||||
let mut out = Vec::with_capacity(self.rows);
|
||||
// Re‑use one buffer for all rows to avoid repeated allocations.
|
||||
let mut buf = vec![0.0f64; self.cols];
|
||||
for r in 0..self.rows {
|
||||
for c in 0..self.cols {
|
||||
buf[c] = self[(r, c)];
|
||||
}
|
||||
out.push(f(&buf));
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Generic helper that dispatches to [`Matrix::apply_colwise`] or
|
||||
/// [`Matrix::apply_rowwise`] depending on `axis`.
|
||||
#[inline]
|
||||
pub fn apply_axis<U, F>(&self, axis: Axis, f: F) -> Vec<U>
|
||||
where
|
||||
F: FnMut(&[f64]) -> U,
|
||||
{
|
||||
match axis {
|
||||
Axis::Col => self.apply_colwise(f),
|
||||
Axis::Row => self.apply_rowwise(f),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
// Convenience reductions built on top of `apply_axis`.
|
||||
// By convention "vertical" = column‑wise, "horizontal" = row‑wise.
|
||||
// ---------------------------------------------------------------------
|
||||
|
||||
/// Column‑wise sum, ignoring `NaN`s.
|
||||
pub fn sum_vertical(&self) -> FloatVector {
|
||||
self.apply_colwise(|col| col.iter().copied().filter(|v| !v.is_nan()).sum())
|
||||
}
|
||||
|
||||
/// Row‑wise sum, ignoring `NaN`s.
|
||||
pub fn sum_horizontal(&self) -> FloatVector {
|
||||
self.apply_rowwise(|row| row.iter().copied().filter(|v| !v.is_nan()).sum())
|
||||
}
|
||||
|
||||
/// Column‑wise product, ignoring `NaN`s.
|
||||
pub fn prod_vertical(&self) -> FloatVector {
|
||||
self.apply_colwise(|col| {
|
||||
col.iter()
|
||||
.copied()
|
||||
.filter(|v| !v.is_nan())
|
||||
.fold(1.0, |acc, x| acc * x)
|
||||
})
|
||||
}
|
||||
|
||||
/// Row‑wise product, ignoring `NaN`s.
|
||||
pub fn prod_horizontal(&self) -> FloatVector {
|
||||
self.apply_rowwise(|row| {
|
||||
row.iter()
|
||||
.copied()
|
||||
.filter(|v| !v.is_nan())
|
||||
.fold(1.0, |acc, x| acc * x)
|
||||
})
|
||||
}
|
||||
|
||||
/// Column‑wise count of `NaN`s.
|
||||
pub fn count_nan_vertical(&self) -> Vec<usize> {
|
||||
self.apply_colwise(|col| col.iter().filter(|x| x.is_nan()).count())
|
||||
}
|
||||
|
||||
/// Row‑wise count of `NaN`s.
|
||||
pub fn count_nan_horizontal(&self) -> Vec<usize> {
|
||||
self.apply_rowwise(|row| row.iter().filter(|x| x.is_nan()).count())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
// Existing helpers
|
||||
// ---------------------------------------------------------------------
|
||||
|
||||
pub fn is_nan(&self) -> BoolMatrix {
|
||||
let mut data = Vec::with_capacity(self.rows * self.cols);
|
||||
for r in 0..self.rows {
|
||||
for c in 0..self.cols {
|
||||
data.push(self[(r, c)].is_nan());
|
||||
}
|
||||
}
|
||||
BoolMatrix::from_vec(data, self.rows, self.cols)
|
||||
}
|
||||
|
||||
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user