// Reconstructed contents of `src/frame/mat.rs` as introduced by the patch below.
//
//   Commit:  6c82f40b3aa5b2049cac66881f899e8a4edc0107
//   From:    Palash Tyagi <23239946+Magnus167@users.noreply.github.com>
//   Date:    Thu, 17 Apr 2025 11:56:14 +0100
//   Subject: [PATCH] add first draft of a matrix implementation
//
// NOTE(review): the patch text had lost its line breaks and every `<...>` generic
// argument list. The generics below were restored from how each item is used;
// confirm them against the original commit.

use std::ops::{Index, IndexMut};

/// A column-major 2D matrix of `T`.
///
/// Element `(r, c)` is stored at `data[c * rows + r]`, so each column occupies one
/// contiguous run of `rows` elements.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Matrix<T> {
    rows: usize,
    cols: usize,
    data: Vec<T>,
}

impl<T> Matrix<T> {
    /// Build from columns (each inner `Vec` is one column).
    ///
    /// # Panics
    ///
    /// Panics if there are no columns, if the first column is empty, or if the
    /// columns do not all have the same length.
    pub fn from_cols(cols_data: Vec<Vec<T>>) -> Self {
        let cols = cols_data.len();
        assert!(cols > 0, "need at least one column");
        let rows = cols_data[0].len();
        assert!(rows > 0, "need at least one row");
        for (i, col) in cols_data.iter().enumerate().skip(1) {
            assert!(
                col.len() == rows,
                "col {} has len {}, expected {}",
                i,
                col.len(),
                rows
            );
        }
        // Concatenating the columns in order yields the column-major buffer.
        let mut data = Vec::with_capacity(rows * cols);
        for col in cols_data {
            data.extend(col);
        }
        Matrix { rows, cols, data }
    }

    /// Build from a flat buffer that is already in column-major order.
    ///
    /// # Panics
    ///
    /// Panics if `rows` or `cols` is zero, or if `data.len() != rows * cols`.
    pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> Self {
        assert!(rows > 0, "need at least one row");
        assert!(cols > 0, "need at least one column");
        assert_eq!(data.len(), rows * cols, "data length mismatch");
        Matrix { rows, cols, data }
    }

    /// Number of rows.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Shared reference to the element at row `r`, column `c`.
    ///
    /// # Panics
    ///
    /// Panics if `r` or `c` is out of bounds.
    pub fn get(&self, r: usize, c: usize) -> &T {
        &self[(r, c)]
    }

    /// Mutable reference to the element at row `r`, column `c`.
    ///
    /// # Panics
    ///
    /// Panics if `r` or `c` is out of bounds.
    pub fn get_mut(&mut self, r: usize, c: usize) -> &mut T {
        &mut self[(r, c)]
    }

    /// Column `c` as a contiguous slice of length `rows`.
    ///
    /// # Panics
    ///
    /// Panics if `c` is out of bounds.
    #[inline]
    pub fn column(&self, c: usize) -> &[T] {
        assert!(c < self.cols, "column index out of bounds");
        let start = c * self.rows;
        &self.data[start..start + self.rows]
    }

    /// Iterate over the columns, left to right, as slices.
    pub fn iter_columns(&self) -> impl Iterator<Item = &[T]> {
        (0..self.cols).map(move |c| self.column(c))
    }

    /// Iterate over the rows, top to bottom, as [`Row`] views.
    pub fn iter_rows(&self) -> impl Iterator<Item = Row<'_, T>> {
        (0..self.rows).map(move |r| Row {
            matrix: self,
            row: r,
        })
    }

    /// Swaps two columns in the matrix.
    ///
    /// # Panics
    ///
    /// Panics if either index is out of bounds.
    pub fn swap_columns(&mut self, c1: usize, c2: usize) {
        assert!(
            c1 < self.cols && c2 < self.cols,
            "column index out of bounds"
        );
        if c1 == c2 {
            return;
        }

        // Split the buffer at the start of the later column so both columns can be
        // borrowed mutably at once, then swap them slice-for-slice.
        let (lo, hi) = if c1 < c2 { (c1, c2) } else { (c2, c1) };
        let rows = self.rows;
        let (head, tail) = self.data.split_at_mut(hi * rows);
        head[lo * rows..(lo + 1) * rows].swap_with_slice(&mut tail[..rows]);
    }
}

impl<T> Index<(usize, usize)> for Matrix<T> {
    type Output = T;

    /// Indexes by `(row, column)`; panics with "index out of bounds" when either is
    /// out of range.
    #[inline]
    fn index(&self, (r, c): (usize, usize)) -> &T {
        assert!(r < self.rows && c < self.cols, "index out of bounds");
        &self.data[c * self.rows + r]
    }
}

impl<T> IndexMut<(usize, usize)> for Matrix<T> {
    /// Mutable counterpart of `Index`; same bounds behavior.
    #[inline]
    fn index_mut(&mut self, (r, c): (usize, usize)) -> &mut T {
        assert!(r < self.rows && c < self.cols, "index out of bounds");
        &mut self.data[c * self.rows + r]
    }
}

/// A view of one row.
pub struct Row<'a, T> {
    matrix: &'a Matrix<T>,
    row: usize,
}

impl<'a, T> Row<'a, T> {
    /// Element of this row at column `c`.
    ///
    /// # Panics
    ///
    /// Panics if `c` is out of bounds.
    pub fn get(&self, c: usize) -> &T {
        &self.matrix[(self.row, c)]
    }

    /// Iterate over the elements of this row, left to right.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        (0..self.matrix.cols).map(move |c| &self.matrix[(self.row, c)])
    }
}

/// Macro to generate element-wise impls for +, -, *, /
///
/// Each generated impl works on references (`&a + &b`), panics when the two
/// shapes differ, and returns a freshly allocated matrix.
macro_rules! impl_elementwise_op {
    ($OpTrait:ident, $method:ident, $op:tt) => {
        impl<'a, 'b, T> std::ops::$OpTrait<&'b Matrix<T>> for &'a Matrix<T>
        where
            T: Clone + std::ops::$OpTrait<Output = T>,
        {
            type Output = Matrix<T>;

            fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> {
                assert_eq!(self.rows, rhs.rows, "row count mismatch");
                assert_eq!(self.cols, rhs.cols, "col count mismatch");
                // Both buffers share one layout, so pairing them positionally pairs
                // matching (row, column) cells.
                let data = self
                    .data
                    .iter()
                    .cloned()
                    .zip(rhs.data.iter().cloned())
                    .map(|(a, b)| a $op b)
                    .collect();
                Matrix { rows: self.rows, cols: self.cols, data }
            }
        }
    };
}

// invoke it 4 times:
impl_elementwise_op!(Add, add, +);
impl_elementwise_op!(Sub, sub, -);
impl_elementwise_op!(Mul, mul, *);
impl_elementwise_op!(Div, div, /);

/// Axis along which to apply a reduction.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Axis {
    /// Operate column-wise (vertical).
    Col,
    /// Operate row-wise (horizontal).
    Row,
}

/// Matrix of `f64` values.
pub type FloatMatrix = Matrix<f64>;
/// Flat vector of `f64` values, as returned by the floating-point reductions.
pub type FloatVector = Vec<f64>;
/// Matrix of booleans, as returned by [`Matrix::is_nan`].
pub type BoolMatrix = Matrix<bool>;
/// Matrix of integers.
///
/// NOTE(review): the element type of this alias was lost when the patch text was
/// extracted; `i32` is assumed here - confirm against the original commit.
pub type IntMatrix = Matrix<i32>;

impl Matrix<f64> {
    /// Apply a function along *columns* and collect its result in a `Vec`.
    ///
    /// Each column is handed to `f` directly as a slice of the backing buffer.
    #[inline]
    fn apply_colwise<U, F>(&self, mut f: F) -> Vec<U>
    where
        F: FnMut(&[f64]) -> U,
    {
        (0..self.cols).map(|c| f(self.column(c))).collect()
    }

    /// Apply a function along *rows* and collect its result in a `Vec`.
    ///
    /// A row is not contiguous in the column-major buffer, so its values are copied
    /// into a scratch buffer that is reused for every row.
    #[inline]
    fn apply_rowwise<U, F>(&self, mut f: F) -> Vec<U>
    where
        F: FnMut(&[f64]) -> U,
    {
        let mut out = Vec::with_capacity(self.rows);
        // Re-use one buffer for all rows to avoid repeated allocations.
        let mut buf = vec![0.0f64; self.cols];
        for r in 0..self.rows {
            // Row `r` is every `rows`-th element of the buffer, starting at offset `r`.
            let strided = self.data[r..].iter().step_by(self.rows);
            for (slot, &value) in buf.iter_mut().zip(strided) {
                *slot = value;
            }
            out.push(f(&buf));
        }
        out
    }

    /// Generic helper that applies `f` to every column (`Axis::Col`) or every row
    /// (`Axis::Row`) and collects the results in order.
    #[inline]
    pub fn apply_axis<U, F>(&self, axis: Axis, f: F) -> Vec<U>
    where
        F: FnMut(&[f64]) -> U,
    {
        match axis {
            Axis::Col => self.apply_colwise(f),
            Axis::Row => self.apply_rowwise(f),
        }
    }

    // ---------------------------------------------------------------------
    // Convenience reductions.
    // By convention "vertical" = column-wise, "horizontal" = row-wise.
    // ---------------------------------------------------------------------

    /// Column-wise sum, ignoring `NaN`s.
    pub fn sum_vertical(&self) -> FloatVector {
        self.apply_colwise(|col| col.iter().copied().filter(|v| !v.is_nan()).sum::<f64>())
    }

    /// Row-wise sum, ignoring `NaN`s.
    pub fn sum_horizontal(&self) -> FloatVector {
        self.apply_rowwise(|row| row.iter().copied().filter(|v| !v.is_nan()).sum::<f64>())
    }

    /// Column-wise product, ignoring `NaN`s (a column of only `NaN`s yields `1.0`).
    pub fn prod_vertical(&self) -> FloatVector {
        self.apply_colwise(|col| {
            col.iter()
                .copied()
                .filter(|v| !v.is_nan())
                .product::<f64>()
        })
    }

    /// Row-wise product, ignoring `NaN`s (a row of only `NaN`s yields `1.0`).
    pub fn prod_horizontal(&self) -> FloatVector {
        self.apply_rowwise(|row| {
            row.iter()
                .copied()
                .filter(|v| !v.is_nan())
                .product::<f64>()
        })
    }

    /// Column-wise count of `NaN`s.
    pub fn count_nan_vertical(&self) -> Vec<usize> {
        self.apply_colwise(|col| col.iter().filter(|x| x.is_nan()).count())
    }

    /// Row-wise count of `NaN`s.
    pub fn count_nan_horizontal(&self) -> Vec<usize> {
        self.apply_rowwise(|row| row.iter().filter(|x| x.is_nan()).count())
    }

    // ---------------------------------------------------------------------
    // Existing helpers
    // ---------------------------------------------------------------------

    /// Element-wise `NaN` mask with the same shape as `self`.
    ///
    /// The mask is built by walking the backing buffer in storage order, so the
    /// flag for `(r, c)` ends up at `(r, c)` in the result.
    pub fn is_nan(&self) -> BoolMatrix {
        // `from_vec` expects column-major data; mapping the buffer in place keeps
        // that layout (a row-major walk here would scramble non-square masks).
        let data = self.data.iter().map(|v| v.is_nan()).collect();
        BoolMatrix::from_vec(data, self.rows, self.cols)
    }
}