mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Implement map and zip methods for SeriesOps trait and add corresponding tests
This commit is contained in:
parent
64e578fae2
commit
621632b7d2
@ -12,6 +12,14 @@ pub trait SeriesOps {
|
||||
where
|
||||
F: FnMut(&[f64]) -> U;
|
||||
|
||||
fn map<F>(&self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64) -> f64;
|
||||
|
||||
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64, f64) -> f64;
|
||||
|
||||
fn matrix_mul(&self, other: &Self) -> FloatMatrix;
|
||||
fn dot(&self, other: &Self) -> FloatMatrix;
|
||||
|
||||
@ -168,6 +176,42 @@ impl SeriesOps for FloatMatrix {
|
||||
fn dot(&self, other: &Self) -> FloatMatrix {
|
||||
self.matrix_mul(other)
|
||||
}
|
||||
|
||||
fn map<F>(&self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64) -> f64,
|
||||
{
|
||||
let data = self.data().iter().map(|&v| f(v)).collect::<Vec<_>>();
|
||||
FloatMatrix::from_vec(data, self.rows(), self.cols())
|
||||
}
|
||||
|
||||
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64, f64) -> f64,
|
||||
{
|
||||
assert_eq!(
|
||||
self.rows(),
|
||||
other.rows(),
|
||||
"Row count mismatch: {} vs {}",
|
||||
self.rows(),
|
||||
other.rows()
|
||||
);
|
||||
assert_eq!(
|
||||
self.cols(),
|
||||
other.cols(),
|
||||
"Column count mismatch: {} vs {}",
|
||||
self.cols(),
|
||||
other.cols()
|
||||
);
|
||||
|
||||
let data = self
|
||||
.data()
|
||||
.iter()
|
||||
.zip(other.data().iter())
|
||||
.map(|(&a, &b)| f(a, b))
|
||||
.collect();
|
||||
crate::matrix::Matrix::from_vec(data, self.rows(), self.cols())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -343,6 +387,39 @@ mod tests {
|
||||
a.dot(&b); // This should panic due to dimension mismatch
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_series_ops_map() {
|
||||
let matrix = create_float_test_matrix();
|
||||
// Map function to double each value
|
||||
let mapped_matrix = matrix.map(|x| x * 2.0);
|
||||
// Expected data after mapping
|
||||
let expected_data = vec![2.0, 4.0, 6.0, 8.0, f64::NAN, 12.0, 14.0, 16.0, f64::NAN];
|
||||
let expected_matrix = FloatMatrix::from_vec(expected_data, 3, 3);
|
||||
// assert_eq!(mapped_matrix, expected_matrix);
|
||||
for i in 0..mapped_matrix.data().len() {
|
||||
// if not nan, check equality
|
||||
if !mapped_matrix.data()[i].is_nan() {
|
||||
assert_eq!(mapped_matrix.data()[i], expected_matrix.data()[i]);
|
||||
} else {
|
||||
assert!(mapped_matrix.data()[i].is_nan());
|
||||
assert!(expected_matrix.data()[i].is_nan());
|
||||
}
|
||||
}
|
||||
assert_eq!(mapped_matrix.rows(), expected_matrix.rows());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_series_ops_zip() {
|
||||
let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix
|
||||
let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix
|
||||
// Zip function to add corresponding elements
|
||||
let zipped_matrix = a.zip(&b, |x, y| x + y);
|
||||
// Expected data after zipping
|
||||
let expected_data = vec![6.0, 8.0, 10.0, 12.0];
|
||||
let expected_matrix = FloatMatrix::from_vec(expected_data, 2, 2);
|
||||
assert_eq!(zipped_matrix, expected_matrix);
|
||||
}
|
||||
|
||||
// --- Edge Cases for SeriesOps ---
|
||||
|
||||
#[test]
|
||||
|
Loading…
x
Reference in New Issue
Block a user