mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Implement map and zip methods for SeriesOps trait and add corresponding tests
This commit is contained in:
parent
64e578fae2
commit
621632b7d2
@ -12,6 +12,14 @@ pub trait SeriesOps {
|
|||||||
where
|
where
|
||||||
F: FnMut(&[f64]) -> U;
|
F: FnMut(&[f64]) -> U;
|
||||||
|
|
||||||
|
fn map<F>(&self, f: F) -> FloatMatrix
|
||||||
|
where
|
||||||
|
F: Fn(f64) -> f64;
|
||||||
|
|
||||||
|
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
|
||||||
|
where
|
||||||
|
F: Fn(f64, f64) -> f64;
|
||||||
|
|
||||||
fn matrix_mul(&self, other: &Self) -> FloatMatrix;
|
fn matrix_mul(&self, other: &Self) -> FloatMatrix;
|
||||||
fn dot(&self, other: &Self) -> FloatMatrix;
|
fn dot(&self, other: &Self) -> FloatMatrix;
|
||||||
|
|
||||||
@ -168,6 +176,42 @@ impl SeriesOps for FloatMatrix {
|
|||||||
fn dot(&self, other: &Self) -> FloatMatrix {
|
fn dot(&self, other: &Self) -> FloatMatrix {
|
||||||
self.matrix_mul(other)
|
self.matrix_mul(other)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn map<F>(&self, f: F) -> FloatMatrix
|
||||||
|
where
|
||||||
|
F: Fn(f64) -> f64,
|
||||||
|
{
|
||||||
|
let data = self.data().iter().map(|&v| f(v)).collect::<Vec<_>>();
|
||||||
|
FloatMatrix::from_vec(data, self.rows(), self.cols())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
|
||||||
|
where
|
||||||
|
F: Fn(f64, f64) -> f64,
|
||||||
|
{
|
||||||
|
assert_eq!(
|
||||||
|
self.rows(),
|
||||||
|
other.rows(),
|
||||||
|
"Row count mismatch: {} vs {}",
|
||||||
|
self.rows(),
|
||||||
|
other.rows()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
self.cols(),
|
||||||
|
other.cols(),
|
||||||
|
"Column count mismatch: {} vs {}",
|
||||||
|
self.cols(),
|
||||||
|
other.cols()
|
||||||
|
);
|
||||||
|
|
||||||
|
let data = self
|
||||||
|
.data()
|
||||||
|
.iter()
|
||||||
|
.zip(other.data().iter())
|
||||||
|
.map(|(&a, &b)| f(a, b))
|
||||||
|
.collect();
|
||||||
|
crate::matrix::Matrix::from_vec(data, self.rows(), self.cols())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@ -343,6 +387,39 @@ mod tests {
|
|||||||
a.dot(&b); // This should panic due to dimension mismatch
|
a.dot(&b); // This should panic due to dimension mismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_series_ops_map() {
|
||||||
|
let matrix = create_float_test_matrix();
|
||||||
|
// Map function to double each value
|
||||||
|
let mapped_matrix = matrix.map(|x| x * 2.0);
|
||||||
|
// Expected data after mapping
|
||||||
|
let expected_data = vec![2.0, 4.0, 6.0, 8.0, f64::NAN, 12.0, 14.0, 16.0, f64::NAN];
|
||||||
|
let expected_matrix = FloatMatrix::from_vec(expected_data, 3, 3);
|
||||||
|
// assert_eq!(mapped_matrix, expected_matrix);
|
||||||
|
for i in 0..mapped_matrix.data().len() {
|
||||||
|
// if not nan, check equality
|
||||||
|
if !mapped_matrix.data()[i].is_nan() {
|
||||||
|
assert_eq!(mapped_matrix.data()[i], expected_matrix.data()[i]);
|
||||||
|
} else {
|
||||||
|
assert!(mapped_matrix.data()[i].is_nan());
|
||||||
|
assert!(expected_matrix.data()[i].is_nan());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert_eq!(mapped_matrix.rows(), expected_matrix.rows());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_series_ops_zip() {
|
||||||
|
let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix
|
||||||
|
let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix
|
||||||
|
// Zip function to add corresponding elements
|
||||||
|
let zipped_matrix = a.zip(&b, |x, y| x + y);
|
||||||
|
// Expected data after zipping
|
||||||
|
let expected_data = vec![6.0, 8.0, 10.0, 12.0];
|
||||||
|
let expected_matrix = FloatMatrix::from_vec(expected_data, 2, 2);
|
||||||
|
assert_eq!(zipped_matrix, expected_matrix);
|
||||||
|
}
|
||||||
|
|
||||||
// --- Edge Cases for SeriesOps ---
|
// --- Edge Cases for SeriesOps ---
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user