From 4f21290b1c2ce47aee569554e7f158f660f6ac21 Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Fri, 18 Apr 2025 22:36:46 +0100 Subject: [PATCH] add SeriesOps trait and implementations for FloatMatrix to support axis-based operations --- src/matrix/seriesops.rs | 143 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 src/matrix/seriesops.rs diff --git a/src/matrix/seriesops.rs b/src/matrix/seriesops.rs new file mode 100644 index 0000000..b6853d3 --- /dev/null +++ b/src/matrix/seriesops.rs @@ -0,0 +1,143 @@ +use crate::matrix::{Axis, BoolMatrix, FloatMatrix}; + +/// “Series–like” helpers that work along a single axis. +/// +/// *All* the old methods (`sum_*`, `prod_*`, `is_nan`, …) are exposed +/// through this trait, so nothing needs to stay on an `impl Matrix`; +/// just `use SeriesOps` to make the extension methods available. +pub trait SeriesOps { + /// Generic helper: apply `f` to every column/row and collect its + /// result in a `Vec`. + fn apply_axis(&self, axis: Axis, f: F) -> Vec + where + F: FnMut(&[f64]) -> U; + + fn sum_vertical(&self) -> Vec; + fn sum_horizontal(&self) -> Vec; + + fn prod_vertical(&self) -> Vec; + fn prod_horizontal(&self) -> Vec; + + fn cumsum_vertical(&self) -> FloatMatrix; + fn cumsum_horizontal(&self) -> FloatMatrix; + + fn count_nan_vertical(&self) -> Vec; + fn count_nan_horizontal(&self) -> Vec; + + fn is_nan(&self) -> BoolMatrix; +} + +impl SeriesOps for FloatMatrix { + fn apply_axis(&self, axis: Axis, mut f: F) -> Vec + where + F: FnMut(&[f64]) -> U, + { + match axis { + Axis::Col => { + let mut out = Vec::with_capacity(self.cols()); + for c in 0..self.cols() { + out.push(f(self.column(c))); + } + out + } + Axis::Row => { + let mut out = Vec::with_capacity(self.rows()); + let mut buf = vec![0.0; self.cols()]; // reusable buffer + for r in 0..self.rows() { + for c in 0..self.cols() { + buf[c] = self[(r, c)]; + } + out.push(f(&buf)); + } + out + } + } + } + + fn sum_vertical(&self) -> Vec { + self.apply_axis(Axis::Col, |col| { + col.iter().copied().filter(|v| !v.is_nan()).sum::() + }) + } + + fn sum_horizontal(&self) -> Vec { + self.apply_axis(Axis::Row, |row| { + row.iter().copied().filter(|v| !v.is_nan()).sum::() + }) + } + + fn prod_vertical(&self) -> Vec { + self.apply_axis(Axis::Col, |col| { + col.iter() + .copied() + .filter(|v| !v.is_nan()) + .fold(1.0, |acc, x| acc * x) + }) + } + + fn prod_horizontal(&self) -> Vec { + self.apply_axis(Axis::Row, |row| { + row.iter() + .copied() + .filter(|v| !v.is_nan()) + .fold(1.0, |acc, x| acc * x) + }) + } + + fn cumsum_vertical(&self) -> FloatMatrix { + let mut data = Vec::with_capacity(self.rows() * self.cols()); + for c in 0..self.cols() { + let mut acc = 0.0; + for r in 0..self.rows() { + let v = self[(r, c)]; + if !v.is_nan() { + acc += v; + } + data.push(acc); + } + } + FloatMatrix::from_vec(data, self.rows(), self.cols()) + } + + fn cumsum_horizontal(&self) -> FloatMatrix { + // 1. Store row-wise cumulative sums temporarily + let mut row_results: Vec> = Vec::with_capacity(self.rows()); + for r in 0..self.rows() { + let mut row_data = Vec::with_capacity(self.cols()); + let mut acc = 0.0; + for c in 0..self.cols() { + let v = self[(r, c)]; + if !v.is_nan() { + acc += v; + } + row_data.push(acc); + } + row_results.push(row_data); + } + + // 2. Build the final data vector in column-major order + let mut final_data = Vec::with_capacity(self.rows() * self.cols()); + for c in 0..self.cols() { + for r in 0..self.rows() { + // Get the element from row 'r', column 'c' of the row_results + final_data.push(row_results[r][c]); + } + } + + // 3. Construct the matrix using the correctly ordered data + FloatMatrix::from_vec(final_data, self.rows(), self.cols()) + } + + fn count_nan_vertical(&self) -> Vec { + self.apply_axis(Axis::Col, |col| col.iter().filter(|x| x.is_nan()).count()) + } + + fn count_nan_horizontal(&self) -> Vec { + self.apply_axis(Axis::Row, |row| row.iter().filter(|x| x.is_nan()).count()) + } + + fn is_nan(&self) -> BoolMatrix { + let data = self.data().iter().map(|v| v.is_nan()).collect(); + BoolMatrix::from_vec(data, self.rows(), self.cols()) + } +}