diff --git a/src/matrix/seriesops.rs b/src/matrix/seriesops.rs index 65702e9..0c816e1 100644 --- a/src/matrix/seriesops.rs +++ b/src/matrix/seriesops.rs @@ -12,6 +12,14 @@ pub trait SeriesOps { where F: FnMut(&[f64]) -> U; + fn map(&self, f: F) -> FloatMatrix + where + F: Fn(f64) -> f64; + + fn zip(&self, other: &Self, f: F) -> FloatMatrix + where + F: Fn(f64, f64) -> f64; + fn matrix_mul(&self, other: &Self) -> FloatMatrix; fn dot(&self, other: &Self) -> FloatMatrix; @@ -168,6 +176,42 @@ impl SeriesOps for FloatMatrix { fn dot(&self, other: &Self) -> FloatMatrix { self.matrix_mul(other) } + + fn map(&self, f: F) -> FloatMatrix + where + F: Fn(f64) -> f64, + { + let data = self.data().iter().map(|&v| f(v)).collect::>(); + FloatMatrix::from_vec(data, self.rows(), self.cols()) + } + + fn zip(&self, other: &Self, f: F) -> FloatMatrix + where + F: Fn(f64, f64) -> f64, + { + assert_eq!( + self.rows(), + other.rows(), + "Row count mismatch: {} vs {}", + self.rows(), + other.rows() + ); + assert_eq!( + self.cols(), + other.cols(), + "Column count mismatch: {} vs {}", + self.cols(), + other.cols() + ); + + let data = self + .data() + .iter() + .zip(other.data().iter()) + .map(|(&a, &b)| f(a, b)) + .collect(); + crate::matrix::Matrix::from_vec(data, self.rows(), self.cols()) + } } #[cfg(test)] @@ -343,6 +387,39 @@ mod tests { a.dot(&b); // This should panic due to dimension mismatch } + #[test] + fn test_series_ops_map() { + let matrix = create_float_test_matrix(); + // Map function to double each value + let mapped_matrix = matrix.map(|x| x * 2.0); + // Expected data after mapping + let expected_data = vec![2.0, 4.0, 6.0, 8.0, f64::NAN, 12.0, 14.0, 16.0, f64::NAN]; + let expected_matrix = FloatMatrix::from_vec(expected_data, 3, 3); + // assert_eq!(mapped_matrix, expected_matrix); + for i in 0..mapped_matrix.data().len() { + // if not nan, check equality + if !mapped_matrix.data()[i].is_nan() { + assert_eq!(mapped_matrix.data()[i], expected_matrix.data()[i]); + } else { + assert!(mapped_matrix.data()[i].is_nan()); + assert!(expected_matrix.data()[i].is_nan()); + } + } + assert_eq!(mapped_matrix.rows(), expected_matrix.rows()); + } + + #[test] + fn test_series_ops_zip() { + let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix + let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix + // Zip function to add corresponding elements + let zipped_matrix = a.zip(&b, |x, y| x + y); + // Expected data after zipping + let expected_data = vec![6.0, 8.0, 10.0, 12.0]; + let expected_matrix = FloatMatrix::from_vec(expected_data, 2, 2); + assert_eq!(zipped_matrix, expected_matrix); + } + // --- Edge Cases for SeriesOps --- #[test]