diff --git a/README.md b/README.md index 6bd9927..a316f3e 100644 --- a/README.md +++ b/README.md @@ -94,10 +94,38 @@ assert!(check); // The above math can also be written as: let check: bool = ((((ma.clone() + 1.0) - 1.0) * 2.0) / 2.0) - .eq_elem(ma) + .eq_elem(ma.clone()) .all(); assert!(check); +// Matrix multiplication +let mc: Matrix = Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); +let md: Matrix = Matrix::from_cols(vec![vec![5.0, 6.0], vec![7.0, 8.0]]); +let mul_result: Matrix = mc.matrix_mul(&md); +// Expected: +// 1*5 + 3*6 = 5 + 18 = 23 +// 2*5 + 4*6 = 10 + 24 = 34 +// 1*7 + 3*8 = 7 + 24 = 31 +// 2*7 + 4*8 = 14 + 32 = 46 +assert_eq!(mul_result.data(), &[23.0, 34.0, 31.0, 46.0]); + +// Dot product (alias for matrix_mul for FloatMatrix) +let dot_result: Matrix = mc.dot(&md); +assert_eq!(dot_result, mul_result); + +// Transpose +let original_matrix: Matrix = Matrix::from_cols(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]); +// Original: +// 1 4 +// 2 5 +// 3 6 +let transposed_matrix: Matrix = original_matrix.transpose(); +// Transposed: +// 1 2 3 +// 4 5 6 +assert_eq!(transposed_matrix.rows(), 2); +assert_eq!(transposed_matrix.cols(), 3); +assert_eq!(transposed_matrix.data(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]); ``` diff --git a/src/frame/base.rs b/src/frame/base.rs index 4b54436..553062c 100644 --- a/src/frame/base.rs +++ b/src/frame/base.rs @@ -483,7 +483,6 @@ impl Frame { deleted_data } - /// Returns a new `Matrix` that is the transpose of the current frame's matrix. pub fn transpose(&self) -> Matrix { self.matrix.transpose() @@ -525,6 +524,45 @@ impl Frame { } } + pub fn frame_map(&self, f: impl Fn(&T) -> T) -> Frame { + Frame::new( + Matrix::from_vec( + self.matrix.data().iter().map(f).collect(), + self.matrix.rows(), + self.matrix.cols(), + ), + self.column_names.clone(), + Some(self.index.clone()), + ) + } + + pub fn frame_zip(&self, other: &Frame, f: impl Fn(&T, &T) -> T) -> Frame { + if self.rows() != other.rows() || self.cols() != other.cols() { + panic!( + "Frame::frame_zip: incompatible dimensions (self: {}x{}, other: {}x{})", + self.rows(), + self.cols(), + other.rows(), + other.cols() + ); + } + + Frame::new( + Matrix::from_vec( + self.matrix + .data() + .iter() + .zip(other.matrix.data()) + .map(|(a, b)| f(a, b)) + .collect(), + self.rows(), + self.cols(), + ), + self.column_names.clone(), + Some(self.index.clone()), + ) + } + // Internal helpers /// Rebuilds the column lookup map to match the current `column_names` ordering. @@ -1666,6 +1704,45 @@ mod tests { assert_eq!(frame1.columns(), &["Z"]); } + #[test] + fn test_frame_map() { + let frame = create_test_frame_f64(); // A=[1,2,3], B=[4,5,6] + let mapped_frame = frame.frame_map(|x| x * 2.0); // Multiply each value by 2.0 + assert_eq!(mapped_frame.columns(), frame.columns()); + assert_eq!(mapped_frame.index(), frame.index()); + assert!((mapped_frame["A"][0] - 2.0).abs() < FLOAT_TOLERANCE); + assert!((mapped_frame["A"][1] - 4.0).abs() < FLOAT_TOLERANCE); + assert!((mapped_frame["A"][2] - 6.0).abs() < FLOAT_TOLERANCE); + assert!((mapped_frame["B"][0] - 8.0).abs() < FLOAT_TOLERANCE); + assert!((mapped_frame["B"][1] - 10.0).abs() < FLOAT_TOLERANCE); + assert!((mapped_frame["B"][2] - 12.0).abs() < FLOAT_TOLERANCE); + } + + #[test] + fn test_frame_zip() { + let f1 = create_test_frame_f64(); // A=[1,2,3], B=[4,5,6] + let f2 = create_test_frame_f64_alt(); // A=[0.1,0.2,0.3], B=[0.4,0.5,0.6] + let zipped_frame = f1.frame_zip(&f2, |x, y| x + y); // Element-wise addition + assert_eq!(zipped_frame.columns(), f1.columns()); + assert_eq!(zipped_frame.index(), f1.index()); + assert!((zipped_frame["A"][0] - 1.1).abs() < FLOAT_TOLERANCE); + assert!((zipped_frame["A"][1] - 2.2).abs() < FLOAT_TOLERANCE); + assert!((zipped_frame["A"][2] - 3.3).abs() < FLOAT_TOLERANCE); + assert!((zipped_frame["B"][0] - 4.4).abs() < FLOAT_TOLERANCE); + assert!((zipped_frame["B"][1] - 5.5).abs() < FLOAT_TOLERANCE); + assert!((zipped_frame["B"][2] - 6.6).abs() < FLOAT_TOLERANCE); + } + + #[test] + #[should_panic(expected = "Frame::frame_zip: incompatible dimensions (self: 3x1, other: 3x2)")] + fn test_frame_zip_panic() { + let mut f1 = create_test_frame_f64(); + let f2 = create_test_frame_f64_alt(); + f1.delete_column("B"); + + f1.frame_zip(&f2, |x, y| x + y); // Should panic due to different column counts + } + // --- Element-wise Arithmetic Ops Tests --- #[test] fn test_frame_arithmetic_ops_f64() { diff --git a/src/frame/ops.rs b/src/frame/ops.rs index dea2396..1c3254b 100644 --- a/src/frame/ops.rs +++ b/src/frame/ops.rs @@ -20,6 +20,21 @@ impl SeriesOps for Frame { { self.matrix().apply_axis(axis, f) } + + fn map(&self, f: F) -> FloatMatrix + where + F: Fn(f64) -> f64, + { + self.matrix().map(f) + } + + fn zip(&self, other: &Self, f: F) -> FloatMatrix + where + F: Fn(f64, f64) -> f64, + { + self.matrix().zip(other.matrix(), f) + } + fn matrix_mul(&self, other: &Self) -> FloatMatrix { self.matrix().matrix_mul(other.matrix()) } @@ -113,7 +128,7 @@ mod tests { let col_names = vec!["A".to_string(), "B".to_string()]; let frame = Frame::new( Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]), - col_names, + col_names.clone(), None, ); assert_eq!(frame.sum_vertical(), frame.matrix().sum_vertical()); @@ -136,16 +151,32 @@ mod tests { assert_eq!(frame.is_nan(), frame.matrix().is_nan()); assert_eq!(frame.apply_axis(Axis::Row, |x| x[0] + x[1]), vec![4.0, 6.0]); - - assert_eq!(frame.matrix_mul(&frame), frame.matrix().matrix_mul(&frame.matrix())); + assert_eq!( + frame.matrix_mul(&frame), + frame.matrix().matrix_mul(&frame.matrix()) + ); assert_eq!(frame.dot(&frame), frame.matrix().dot(&frame.matrix())); - + // test transpose - returns a matrix. let frame_transposed_mat = frame.transpose(); let frame_mat_transposed = frame.matrix().transpose(); assert_eq!(frame_transposed_mat, frame_mat_transposed); assert_eq!(frame.matrix(), &frame.matrix().transpose().transpose()); + // test map + let mapped_frame = frame.map(|x| x * 2.0); + let expected_matrix = frame.matrix().map(|x| x * 2.0); + assert_eq!(mapped_frame, expected_matrix); + + // test zip + let other_frame = Frame::new( + Matrix::from_cols(vec![vec![5.0, 6.0], vec![7.0, 8.0]]), + col_names.clone(), + None, + ); + let zipped_frame = frame.zip(&other_frame, |x, y| x + y); + let expected_zipped_matrix = frame.matrix().zip(other_frame.matrix(), |x, y| x + y); + assert_eq!(zipped_frame, expected_zipped_matrix); } #[test] @@ -170,7 +201,4 @@ mod tests { vec![false, false] ); } - - - } diff --git a/src/matrix/seriesops.rs b/src/matrix/seriesops.rs index 65702e9..7072130 100644 --- a/src/matrix/seriesops.rs +++ b/src/matrix/seriesops.rs @@ -12,6 +12,14 @@ pub trait SeriesOps { where F: FnMut(&[f64]) -> U; + fn map(&self, f: F) -> FloatMatrix + where + F: Fn(f64) -> f64; + + fn zip(&self, other: &Self, f: F) -> FloatMatrix + where + F: Fn(f64, f64) -> f64; + fn matrix_mul(&self, other: &Self) -> FloatMatrix; fn dot(&self, other: &Self) -> FloatMatrix; @@ -168,6 +176,36 @@ impl SeriesOps for FloatMatrix { fn dot(&self, other: &Self) -> FloatMatrix { self.matrix_mul(other) } + + fn map(&self, f: F) -> FloatMatrix + where + F: Fn(f64) -> f64, + { + let data = self.data().iter().map(|&v| f(v)).collect::>(); + FloatMatrix::from_vec(data, self.rows(), self.cols()) + } + + fn zip(&self, other: &Self, f: F) -> FloatMatrix + where + F: Fn(f64, f64) -> f64, + { + assert!( + self.rows() == other.rows() && self.cols() == other.cols(), + "Matrix dimensions mismatch: left is {}x{}, right is {}x{}", + self.rows(), + self.cols(), + other.rows(), + other.cols() + ); + + let data = self + .data() + .iter() + .zip(other.data().iter()) + .map(|(&a, &b)| f(a, b)) + .collect(); + crate::matrix::Matrix::from_vec(data, self.rows(), self.cols()) + } } #[cfg(test)] @@ -343,6 +381,48 @@ mod tests { a.dot(&b); // This should panic due to dimension mismatch } + #[test] + fn test_series_ops_map() { + let matrix = create_float_test_matrix(); + // Map function to double each value + let mapped_matrix = matrix.map(|x| x * 2.0); + // Expected data after mapping + let expected_data = vec![2.0, 4.0, 6.0, 8.0, f64::NAN, 12.0, 14.0, 16.0, f64::NAN]; + let expected_matrix = FloatMatrix::from_vec(expected_data, 3, 3); + // assert_eq!(mapped_matrix, expected_matrix); + for i in 0..mapped_matrix.data().len() { + // if not nan, check equality + if !mapped_matrix.data()[i].is_nan() { + assert_eq!(mapped_matrix.data()[i], expected_matrix.data()[i]); + } else { + assert!(mapped_matrix.data()[i].is_nan()); + assert!(expected_matrix.data()[i].is_nan()); + } + } + assert_eq!(mapped_matrix.rows(), expected_matrix.rows()); + } + + #[test] + fn test_series_ops_zip() { + let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix + let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix + // Zip function to add corresponding elements + let zipped_matrix = a.zip(&b, |x, y| x + y); + // Expected data after zipping + let expected_data = vec![6.0, 8.0, 10.0, 12.0]; + let expected_matrix = FloatMatrix::from_vec(expected_data, 2, 2); + assert_eq!(zipped_matrix, expected_matrix); + } + + #[test] + #[should_panic(expected = "Matrix dimensions mismatch: left is 2x2, right is 3x2")] + fn test_series_ops_zip_panic() { + let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix + let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 3, 2); // 3x2 matrix + // This should panic due to dimension mismatch + a.zip(&b, |x, y| x + y); + } + // --- Edge Cases for SeriesOps --- #[test]