diff --git a/README.md b/README.md index f31cb9c..93624e0 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ let fb = Frame::new(mb, col_names, Some(RowIndex::Date(dates))); // Math that reads like math let result = &fa * &fb; // element‑wise multiply -let total = result.matrix().sum_vertical().iter().sum::(); +let total = result.sum_vertical().iter().sum::(); assert_eq!(total, 184.0); let result = &ma + 1.0; // add scalar @@ -75,6 +75,12 @@ let result = &result - 1.0; // subtract scalar let result = &result * 2.0; // multiply by scalar let result = &result / 2.0; // divide by scalar -let check = result.eq_elementwise(ma).all(); +let check = result.eq_elementwise(ma.clone()).all(); +assert!(check); + +// The above math can also be written as: +let check = &(&(&(&(&ma + 1.0) - 1.0) * 2.0) / 2.0) + .eq_elementwise(ma) + .all(); assert!(check); ``` diff --git a/src/frame/mod.rs b/src/frame/mod.rs index 550fcea..e7840d4 100644 --- a/src/frame/mod.rs +++ b/src/frame/mod.rs @@ -1,3 +1,7 @@ pub mod base; +pub mod ops; -pub use base::*; \ No newline at end of file +pub use base::*; + +#[allow(unused_imports)] +pub use ops::*; \ No newline at end of file diff --git a/src/frame/ops.rs b/src/frame/ops.rs new file mode 100644 index 0000000..f52ee9d --- /dev/null +++ b/src/frame/ops.rs @@ -0,0 +1,154 @@ +use crate::frame::Frame; +use crate::matrix::{Axis, BoolMatrix, BoolOps, FloatMatrix, SeriesOps}; + +// Macro to delegate method calls to self.matrix() +macro_rules! delegate_to_matrix { + ($($method_name:ident -> $return_type:ty),* $(,)?) => { + $( + fn $method_name(&self) -> $return_type { + self.matrix().$method_name() + } + )* + }; +} + +impl SeriesOps for Frame { + #[allow(unused_mut)] + fn apply_axis(&self, axis: Axis, mut f: F) -> Vec + where + F: FnMut(&[f64]) -> U, + { + self.matrix().apply_axis(axis, f) + } + + delegate_to_matrix!( + sum_vertical -> Vec, + sum_horizontal -> Vec, + prod_horizontal -> Vec, + prod_vertical -> Vec, + cumsum_horizontal -> FloatMatrix, + cumsum_vertical -> FloatMatrix, + count_nan_vertical -> Vec, + count_nan_horizontal -> Vec, + is_nan -> BoolMatrix + ); +} + +impl BoolOps for Frame { + fn apply_axis(&self, axis: Axis, f: F) -> Vec + where + F: FnMut(&[bool]) -> U, + { + self.matrix().apply_axis(axis, f) + } + + delegate_to_matrix!( + any_vertical -> Vec, + any_horizontal -> Vec, + all_vertical -> Vec, + all_horizontal -> Vec, + count_vertical -> Vec, + count_horizontal -> Vec, + any -> bool, + all -> bool, + count -> usize + ); +} + +// use crate::frame::Frame; +// use crate::matrix::{Axis, SeriesOps, FloatMatrix, BoolMatrix}; + +// impl SeriesOps for Frame { +// fn apply_axis(&self, axis: Axis, mut f: F) -> Vec +// where +// F: FnMut(&[f64]) -> U, +// { +// self.matrix().apply_axis(axis, f) +// } + +// fn sum_vertical(&self) -> Vec { +// self.matrix().sum_vertical() +// } +// fn sum_horizontal(&self) -> Vec { +// self.matrix().sum_horizontal() +// } +// fn prod_horizontal(&self) -> Vec { +// self.matrix().prod_horizontal() +// } +// fn prod_vertical(&self) -> Vec { +// self.matrix().prod_vertical() +// } +// fn cumsum_horizontal(&self) -> FloatMatrix { +// self.matrix().cumsum_horizontal() +// } +// fn cumsum_vertical(&self) -> FloatMatrix { +// self.matrix().cumsum_vertical() +// } + +// fn count_nan_vertical(&self) -> Vec { +// self.matrix().count_nan_vertical() +// } +// fn count_nan_horizontal(&self) -> Vec { +// self.matrix().count_nan_horizontal() +// } +// fn is_nan(&self) -> BoolMatrix { +// self.matrix().is_nan() +// } +// } + +mod tests { + use crate::frame::*; + use crate::matrix::*; + + #[test] + fn test_series_ops() { + let col_names = vec!["A".to_string(), "B".to_string()]; + let frame = Frame::new( + Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]), + col_names, + None, + ); + assert_eq!(frame.sum_vertical(), frame.matrix().sum_vertical()); + assert_eq!(frame.sum_horizontal(), frame.matrix().sum_horizontal()); + assert_eq!(frame.prod_horizontal(), frame.matrix().prod_horizontal()); + assert_eq!(frame.prod_vertical(), frame.matrix().prod_vertical()); + assert_eq!( + frame.cumsum_horizontal(), + frame.matrix().cumsum_horizontal() + ); + assert_eq!(frame.cumsum_vertical(), frame.matrix().cumsum_vertical()); + assert_eq!( + frame.count_nan_vertical(), + frame.matrix().count_nan_vertical() + ); + assert_eq!( + frame.count_nan_horizontal(), + frame.matrix().count_nan_horizontal() + ); + assert_eq!(frame.is_nan(), frame.matrix().is_nan()); + assert_eq!(frame.apply_axis(Axis::Row, |x| x[0] + x[1]), vec![4.0, 6.0]); + } + #[test] + + fn test_bool_ops() { + let col_names = vec!["A".to_string(), "B".to_string()]; + let frame = Frame::new( + Matrix::from_cols(vec![vec![true, false], vec![false, true]]), + col_names, + None, + ); + assert_eq!(frame.any_vertical(), vec![true, true]); + assert_eq!(frame.any_horizontal(), vec![true, true]); + assert_eq!(frame.all_horizontal(), vec![false, false]); + assert_eq!(frame.all_vertical(), vec![false, false]); + assert_eq!(frame.count_vertical(), vec![1, 1]); + assert_eq!(frame.count_horizontal(), vec![1, 1]); + assert_eq!(frame.any(), true); + assert_eq!(frame.all(), false); + assert_eq!(frame.count(), 2); + assert_eq!( + frame.apply_axis(Axis::Row, |x| x[0] && x[1]), + vec![false, false] + ); + } +}