Implement map and zip methods for SeriesOps trait and add corresponding tests

This commit is contained in:
Palash Tyagi 2025-07-06 01:55:35 +01:00
parent 64e578fae2
commit 621632b7d2

View File

@ -12,6 +12,14 @@ pub trait SeriesOps {
where
F: FnMut(&[f64]) -> U;
fn map<F>(&self, f: F) -> FloatMatrix
where
F: Fn(f64) -> f64;
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
where
F: Fn(f64, f64) -> f64;
fn matrix_mul(&self, other: &Self) -> FloatMatrix;
fn dot(&self, other: &Self) -> FloatMatrix;
@ -168,6 +176,42 @@ impl SeriesOps for FloatMatrix {
fn dot(&self, other: &Self) -> FloatMatrix {
self.matrix_mul(other)
}
fn map<F>(&self, f: F) -> FloatMatrix
where
F: Fn(f64) -> f64,
{
let data = self.data().iter().map(|&v| f(v)).collect::<Vec<_>>();
FloatMatrix::from_vec(data, self.rows(), self.cols())
}
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
where
F: Fn(f64, f64) -> f64,
{
assert_eq!(
self.rows(),
other.rows(),
"Row count mismatch: {} vs {}",
self.rows(),
other.rows()
);
assert_eq!(
self.cols(),
other.cols(),
"Column count mismatch: {} vs {}",
self.cols(),
other.cols()
);
let data = self
.data()
.iter()
.zip(other.data().iter())
.map(|(&a, &b)| f(a, b))
.collect();
crate::matrix::Matrix::from_vec(data, self.rows(), self.cols())
}
}
#[cfg(test)]
@ -343,6 +387,39 @@ mod tests {
a.dot(&b); // This should panic due to dimension mismatch
}
#[test]
fn test_series_ops_map() {
let matrix = create_float_test_matrix();
// Map function to double each value
let mapped_matrix = matrix.map(|x| x * 2.0);
// Expected data after mapping
let expected_data = vec![2.0, 4.0, 6.0, 8.0, f64::NAN, 12.0, 14.0, 16.0, f64::NAN];
let expected_matrix = FloatMatrix::from_vec(expected_data, 3, 3);
// assert_eq!(mapped_matrix, expected_matrix);
for i in 0..mapped_matrix.data().len() {
// if not nan, check equality
if !mapped_matrix.data()[i].is_nan() {
assert_eq!(mapped_matrix.data()[i], expected_matrix.data()[i]);
} else {
assert!(mapped_matrix.data()[i].is_nan());
assert!(expected_matrix.data()[i].is_nan());
}
}
assert_eq!(mapped_matrix.rows(), expected_matrix.rows());
}
#[test]
fn test_series_ops_zip() {
let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix
let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix
// Zip function to add corresponding elements
let zipped_matrix = a.zip(&b, |x, y| x + y);
// Expected data after zipping
let expected_data = vec![6.0, 8.0, 10.0, 12.0];
let expected_matrix = FloatMatrix::from_vec(expected_data, 2, 2);
assert_eq!(zipped_matrix, expected_matrix);
}
// --- Edge Cases for SeriesOps ---
#[test]