rustframe/src/matrix/seriesops.rs
2025-04-20 01:46:57 +01:00

364 lines
14 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use crate::matrix::{Axis, BoolMatrix, FloatMatrix};
/// “Serieslike” helpers that work along a single axis.
///
/// *All* the old methods (`sum_*`, `prod_*`, `is_nan`, …) are exposed
/// through this trait, so nothing needs to stay on an `impl Matrix<f64>`;
/// just `use SeriesOps` to make the extension methods available.
pub trait SeriesOps {
/// Generic helper: apply `f` to every column/row and collect its
/// result in a `Vec`.
fn apply_axis<U, F>(&self, axis: Axis, f: F) -> Vec<U>
where
F: FnMut(&[f64]) -> U;
fn sum_vertical(&self) -> Vec<f64>;
fn sum_horizontal(&self) -> Vec<f64>;
fn prod_vertical(&self) -> Vec<f64>;
fn prod_horizontal(&self) -> Vec<f64>;
fn cumsum_vertical(&self) -> FloatMatrix;
fn cumsum_horizontal(&self) -> FloatMatrix;
fn count_nan_vertical(&self) -> Vec<usize>;
fn count_nan_horizontal(&self) -> Vec<usize>;
fn is_nan(&self) -> BoolMatrix;
}
impl SeriesOps for FloatMatrix {
fn apply_axis<U, F>(&self, axis: Axis, mut f: F) -> Vec<U>
where
F: FnMut(&[f64]) -> U,
{
match axis {
Axis::Col => {
let mut out = Vec::with_capacity(self.cols());
for c in 0..self.cols() {
out.push(f(self.column(c)));
}
out
}
Axis::Row => {
let mut out = Vec::with_capacity(self.rows());
let mut buf = vec![0.0; self.cols()]; // reusable buffer
for r in 0..self.rows() {
for c in 0..self.cols() {
buf[c] = self[(r, c)];
}
out.push(f(&buf));
}
out
}
}
}
fn sum_vertical(&self) -> Vec<f64> {
self.apply_axis(Axis::Col, |col| {
col.iter().copied().filter(|v| !v.is_nan()).sum::<f64>()
})
}
fn sum_horizontal(&self) -> Vec<f64> {
self.apply_axis(Axis::Row, |row| {
row.iter().copied().filter(|v| !v.is_nan()).sum::<f64>()
})
}
fn prod_vertical(&self) -> Vec<f64> {
self.apply_axis(Axis::Col, |col| {
col.iter()
.copied()
.filter(|v| !v.is_nan())
.fold(1.0, |acc, x| acc * x)
})
}
fn prod_horizontal(&self) -> Vec<f64> {
self.apply_axis(Axis::Row, |row| {
row.iter()
.copied()
.filter(|v| !v.is_nan())
.fold(1.0, |acc, x| acc * x)
})
}
fn cumsum_vertical(&self) -> FloatMatrix {
let mut data = Vec::with_capacity(self.rows() * self.cols());
for c in 0..self.cols() {
let mut acc = 0.0;
for r in 0..self.rows() {
let v = self[(r, c)];
if !v.is_nan() {
acc += v;
}
data.push(acc);
}
}
FloatMatrix::from_vec(data, self.rows(), self.cols())
}
fn cumsum_horizontal(&self) -> FloatMatrix {
// 1. Store row-wise cumulative sums temporarily
let mut row_results: Vec<Vec<f64>> = Vec::with_capacity(self.rows());
for r in 0..self.rows() {
let mut row_data = Vec::with_capacity(self.cols());
let mut acc = 0.0;
for c in 0..self.cols() {
let v = self[(r, c)];
if !v.is_nan() {
acc += v;
}
row_data.push(acc);
}
row_results.push(row_data);
}
// 2. Build the final data vector in column-major order
let mut final_data = Vec::with_capacity(self.rows() * self.cols());
for c in 0..self.cols() {
for r in 0..self.rows() {
// Get the element from row 'r', column 'c' of the row_results
final_data.push(row_results[r][c]);
}
}
// 3. Construct the matrix using the correctly ordered data
FloatMatrix::from_vec(final_data, self.rows(), self.cols())
}
fn count_nan_vertical(&self) -> Vec<usize> {
self.apply_axis(Axis::Col, |col| col.iter().filter(|x| x.is_nan()).count())
}
fn count_nan_horizontal(&self) -> Vec<usize> {
self.apply_axis(Axis::Row, |row| row.iter().filter(|x| x.is_nan()).count())
}
fn is_nan(&self) -> BoolMatrix {
let data = self.data().iter().map(|v| v.is_nan()).collect();
BoolMatrix::from_vec(data, self.rows(), self.cols())
}
}
#[cfg(test)]
mod tests {
use super::*;
// Helper function to create a FloatMatrix for SeriesOps testing
fn create_float_test_matrix() -> FloatMatrix {
// 3x3 matrix (column-major) with some NaNs
// 1.0 4.0 7.0
// 2.0 NaN 8.0
// 3.0 6.0 NaN
let data = vec![1.0, 2.0, 3.0, 4.0, f64::NAN, 6.0, 7.0, 8.0, f64::NAN];
FloatMatrix::from_vec(data, 3, 3)
}
// --- Tests for SeriesOps (FloatMatrix) ---
#[test]
fn test_series_ops_sum_vertical() {
let matrix = create_float_test_matrix();
// Col 0: 1.0 + 2.0 + 3.0 = 6.0
// Col 1: 4.0 + NaN + 6.0 = 10.0 (NaN ignored)
// Col 2: 7.0 + 8.0 + NaN = 15.0 (NaN ignored)
let expected = vec![6.0, 10.0, 15.0];
assert_eq!(matrix.sum_vertical(), expected);
}
#[test]
fn test_series_ops_sum_horizontal() {
let matrix = create_float_test_matrix();
// Row 0: 1.0 + 4.0 + 7.0 = 12.0
// Row 1: 2.0 + NaN + 8.0 = 10.0 (NaN ignored)
// Row 2: 3.0 + 6.0 + NaN = 9.0 (NaN ignored)
let expected = vec![12.0, 10.0, 9.0];
assert_eq!(matrix.sum_horizontal(), expected);
}
#[test]
fn test_series_ops_prod_vertical() {
let matrix = create_float_test_matrix();
// Col 0: 1.0 * 2.0 * 3.0 = 6.0
// Col 1: 4.0 * NaN * 6.0 = 24.0 (NaN ignored, starts with 1.0)
// Col 2: 7.0 * 8.0 * NaN = 56.0 (NaN ignored, starts with 1.0)
let expected = vec![6.0, 24.0, 56.0];
assert_eq!(matrix.prod_vertical(), expected);
}
#[test]
fn test_series_ops_prod_horizontal() {
let matrix = create_float_test_matrix();
// Row 0: 1.0 * 4.0 * 7.0 = 28.0
// Row 1: 2.0 * NaN * 8.0 = 16.0 (NaN ignored, starts with 1.0)
// Row 2: 3.0 * 6.0 * NaN = 18.0 (NaN ignored, starts with 1.0)
let expected = vec![28.0, 16.0, 18.0];
assert_eq!(matrix.prod_horizontal(), expected);
}
#[test]
fn test_series_ops_cumsum_vertical() {
let matrix = create_float_test_matrix();
// Col 0: [1.0, 1.0+2.0=3.0, 3.0+3.0=6.0]
// Col 1: [4.0, 4.0+NaN=4.0, 4.0+6.0=10.0] (NaN ignored, cumulative sum doesn't reset)
// Col 2: [7.0, 7.0+8.0=15.0, 15.0+NaN=15.0]
// Expected data (column-major): [1.0, 3.0, 6.0, 4.0, 4.0, 10.0, 7.0, 15.0, 15.0]
let expected_data = vec![1.0, 3.0, 6.0, 4.0, 4.0, 10.0, 7.0, 15.0, 15.0];
let expected_matrix = FloatMatrix::from_vec(expected_data, 3, 3);
assert_eq!(matrix.cumsum_vertical(), expected_matrix);
}
#[test]
fn test_series_ops_cumsum_horizontal() {
let matrix = create_float_test_matrix();
// Row 0: [1.0, 1.0+4.0=5.0, 5.0+7.0=12.0]
// Row 1: [2.0, 2.0+NaN=2.0, 2.0+8.0=10.0] (NaN ignored, cumulative sum doesn't reset)
// Row 2: [3.0, 3.0+6.0=9.0, 9.0+NaN=9.0]
// Expected data (column-major construction from row results):
// Col 0: (R0,C0)=1.0, (R1,C0)=2.0, (R2,C0)=3.0 => [1.0, 2.0, 3.0]
// Col 1: (R0,C1)=5.0, (R1,C1)=2.0, (R2,C1)=9.0 => [5.0, 2.0, 9.0]
// Col 2: (R0,C2)=12.0, (R1,C2)=10.0, (R2,C2)=9.0 => [12.0, 10.0, 9.0]
// Combined data: [1.0, 2.0, 3.0, 5.0, 2.0, 9.0, 12.0, 10.0, 9.0]
let expected_data = vec![1.0, 2.0, 3.0, 5.0, 2.0, 9.0, 12.0, 10.0, 9.0];
let expected_matrix = FloatMatrix::from_vec(expected_data, 3, 3);
assert_eq!(matrix.cumsum_horizontal(), expected_matrix);
}
#[test]
fn test_series_ops_count_nan_vertical() {
let matrix = create_float_test_matrix();
// Col 0: 0 NaNs
// Col 1: 1 NaN
// Col 2: 1 NaN
let expected = vec![0, 1, 1];
assert_eq!(matrix.count_nan_vertical(), expected);
}
#[test]
fn test_series_ops_count_nan_horizontal() {
let matrix = create_float_test_matrix();
// Row 0: 0 NaNs
// Row 1: 1 NaN
// Row 2: 1 NaN
let expected = vec![0, 1, 1];
assert_eq!(matrix.count_nan_horizontal(), expected);
}
#[test]
fn test_series_ops_is_nan() {
let matrix = create_float_test_matrix();
// Original data (col-major): [1.0, 2.0, 3.0, 4.0, NaN, 6.0, 7.0, 8.0, NaN]
// is_nan() applied: [F, F, F, F, T, F, F, F, T]
let expected_data = vec![false, false, false, false, true, false, false, false, true];
let expected_matrix = BoolMatrix::from_vec(expected_data, 3, 3);
assert_eq!(matrix.is_nan(), expected_matrix);
}
// --- Edge Cases for SeriesOps ---
#[test]
fn test_series_ops_1x1() {
let matrix = FloatMatrix::from_vec(vec![42.0], 1, 1);
assert_eq!(matrix.sum_vertical(), vec![42.0]);
assert_eq!(matrix.sum_horizontal(), vec![42.0]);
assert_eq!(matrix.prod_vertical(), vec![42.0]);
assert_eq!(matrix.prod_horizontal(), vec![42.0]);
assert_eq!(matrix.cumsum_vertical().data(), &[42.0]);
assert_eq!(matrix.cumsum_horizontal().data(), &[42.0]);
assert_eq!(matrix.count_nan_vertical(), vec![0]);
assert_eq!(matrix.count_nan_horizontal(), vec![0]);
assert_eq!(matrix.is_nan().data(), &[false]);
let matrix_nan = FloatMatrix::from_vec(vec![f64::NAN], 1, 1);
assert_eq!(matrix_nan.sum_vertical(), vec![0.0]); // sum of empty set is 0
assert_eq!(matrix_nan.sum_horizontal(), vec![0.0]);
assert_eq!(matrix_nan.prod_vertical(), vec![1.0]); // product of empty set is 1
assert_eq!(matrix_nan.prod_horizontal(), vec![1.0]);
assert_eq!(matrix_nan.cumsum_vertical().data(), &[0.0]); // cumsum starts at 0, nan ignored
assert_eq!(matrix_nan.cumsum_horizontal().data(), &[0.0]);
assert_eq!(matrix_nan.count_nan_vertical(), vec![1]);
assert_eq!(matrix_nan.count_nan_horizontal(), vec![1]);
assert_eq!(matrix_nan.is_nan().data(), &[true]);
}
#[test]
fn test_series_ops_1xn_matrix() {
let matrix = FloatMatrix::from_vec(vec![1.0, f64::NAN, 3.0, 4.0], 1, 4); // 1 row, 4 cols
// Data: [1.0, NaN, 3.0, 4.0]
// Vertical (sums/prods/counts per column - each col is just one element)
assert_eq!(matrix.sum_vertical(), vec![1.0, 0.0, 3.0, 4.0]); // NaN sum is 0
assert_eq!(matrix.prod_vertical(), vec![1.0, 1.0, 3.0, 4.0]); // NaN prod is 1
assert_eq!(matrix.count_nan_vertical(), vec![0, 1, 0, 0]);
assert_eq!(matrix.cumsum_vertical().data(), &[1.0, 0.0, 3.0, 4.0]); // Cumsum on single element column
// Horizontal (sums/prods/counts for the single row)
// Row 0: 1.0 + NaN + 3.0 + 4.0 = 8.0
// Row 0: 1.0 * NaN * 3.0 * 4.0 = 12.0
// Row 0: 1 NaN
assert_eq!(matrix.sum_horizontal(), vec![8.0]);
assert_eq!(matrix.prod_horizontal(), vec![12.0]);
assert_eq!(matrix.count_nan_horizontal(), vec![1]);
// Cumsum Horizontal
// Row 0: [1.0, 1.0+NaN=1.0, 1.0+3.0=4.0, 4.0+4.0=8.0]
// Data (col-major): [1.0, 1.0, 4.0, 8.0] (since it's 1 row, data is the same as the row result)
assert_eq!(matrix.cumsum_horizontal().data(), &[1.0, 1.0, 4.0, 8.0]);
// is_nan
// Data: [1.0, NaN, 3.0, 4.0]
// Expected: [F, T, F, F]
assert_eq!(matrix.is_nan().data(), &[false, true, false, false]);
}
#[test]
fn test_series_ops_nx1_matrix() {
let matrix = FloatMatrix::from_vec(vec![1.0, 2.0, f64::NAN, 4.0], 4, 1); // 4 rows, 1 col
// Data: [1.0, 2.0, NaN, 4.0]
// Vertical (sums/prods/counts for the single column)
// Col 0: 1.0 + 2.0 + NaN + 4.0 = 7.0
// Col 0: 1.0 * 2.0 * NaN * 4.0 = 8.0
// Col 0: 1 NaN
assert_eq!(matrix.sum_vertical(), vec![7.0]);
assert_eq!(matrix.prod_vertical(), vec![8.0]);
assert_eq!(matrix.count_nan_vertical(), vec![1]);
// Cumsum Vertical
// Col 0: [1.0, 1.0+2.0=3.0, 3.0+NaN=3.0, 3.0+4.0=7.0]
// Data (col-major): [1.0, 3.0, 3.0, 7.0] (since it's 1 col, data is the same as the col result)
assert_eq!(matrix.cumsum_vertical().data(), &[1.0, 3.0, 3.0, 7.0]);
// Horizontal (sums/prods/counts per row - each row is just one element)
assert_eq!(matrix.sum_horizontal(), vec![1.0, 2.0, 0.0, 4.0]); // NaN sum is 0
assert_eq!(matrix.prod_horizontal(), vec![1.0, 2.0, 1.0, 4.0]); // NaN prod is 1
assert_eq!(matrix.count_nan_horizontal(), vec![0, 0, 1, 0]);
assert_eq!(matrix.cumsum_horizontal().data(), &[1.0, 2.0, 0.0, 4.0]); // Cumsum on single element row
// is_nan
// Data: [1.0, 2.0, NaN, 4.0]
// Expected: [F, F, T, F]
assert_eq!(matrix.is_nan().data(), &[false, false, true, false]);
}
#[test]
fn test_series_ops_all_nan_matrix() {
let matrix = FloatMatrix::from_vec(vec![f64::NAN, f64::NAN, f64::NAN, f64::NAN], 2, 2);
// NaN NaN
// NaN NaN
// Data: [NaN, NaN, NaN, NaN]
assert_eq!(matrix.sum_vertical(), vec![0.0, 0.0]);
assert_eq!(matrix.sum_horizontal(), vec![0.0, 0.0]);
assert_eq!(matrix.prod_vertical(), vec![1.0, 1.0]);
assert_eq!(matrix.prod_horizontal(), vec![1.0, 1.0]);
assert_eq!(matrix.cumsum_vertical().data(), &[0.0, 0.0, 0.0, 0.0]);
assert_eq!(matrix.cumsum_horizontal().data(), &[0.0, 0.0, 0.0, 0.0]);
assert_eq!(matrix.count_nan_vertical(), vec![2, 2]);
assert_eq!(matrix.count_nan_horizontal(), vec![2, 2]);
assert_eq!(matrix.is_nan().data(), &[true, true, true, true]);
}
}