mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Merge pull request #55 from Magnus167/map-zip
Implement `map` and `zip` methods for SeriesOps and Frame
This commit is contained in:
commit
5db5475a61
@ -471,7 +471,6 @@ impl<T: Clone + PartialEq> Frame<T> {
|
||||
deleted_data
|
||||
}
|
||||
|
||||
|
||||
/// Returns a new `Matrix` that is the transpose of the current frame's matrix.
|
||||
pub fn transpose(&self) -> Matrix<T> {
|
||||
self.matrix.transpose()
|
||||
@ -513,6 +512,45 @@ impl<T: Clone + PartialEq> Frame<T> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn frame_map(&self, f: impl Fn(&T) -> T) -> Frame<T> {
|
||||
Frame::new(
|
||||
Matrix::from_vec(
|
||||
self.matrix.data().iter().map(f).collect(),
|
||||
self.matrix.rows(),
|
||||
self.matrix.cols(),
|
||||
),
|
||||
self.column_names.clone(),
|
||||
Some(self.index.clone()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn frame_zip(&self, other: &Frame<T>, f: impl Fn(&T, &T) -> T) -> Frame<T> {
|
||||
if self.rows() != other.rows() || self.cols() != other.cols() {
|
||||
panic!(
|
||||
"Frame::frame_zip: incompatible dimensions (self: {}x{}, other: {}x{})",
|
||||
self.rows(),
|
||||
self.cols(),
|
||||
other.rows(),
|
||||
other.cols()
|
||||
);
|
||||
}
|
||||
|
||||
Frame::new(
|
||||
Matrix::from_vec(
|
||||
self.matrix
|
||||
.data()
|
||||
.iter()
|
||||
.zip(other.matrix.data())
|
||||
.map(|(a, b)| f(a, b))
|
||||
.collect(),
|
||||
self.rows(),
|
||||
self.cols(),
|
||||
),
|
||||
self.column_names.clone(),
|
||||
Some(self.index.clone()),
|
||||
)
|
||||
}
|
||||
|
||||
// Internal helpers
|
||||
|
||||
/// Rebuilds the column lookup map to match the current `column_names` ordering.
|
||||
@ -1654,6 +1692,45 @@ mod tests {
|
||||
assert_eq!(frame1.columns(), &["Z"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frame_map() {
|
||||
let frame = create_test_frame_f64(); // A=[1,2,3], B=[4,5,6]
|
||||
let mapped_frame = frame.frame_map(|x| x * 2.0); // Multiply each value by 2.0
|
||||
assert_eq!(mapped_frame.columns(), frame.columns());
|
||||
assert_eq!(mapped_frame.index(), frame.index());
|
||||
assert!((mapped_frame["A"][0] - 2.0).abs() < FLOAT_TOLERANCE);
|
||||
assert!((mapped_frame["A"][1] - 4.0).abs() < FLOAT_TOLERANCE);
|
||||
assert!((mapped_frame["A"][2] - 6.0).abs() < FLOAT_TOLERANCE);
|
||||
assert!((mapped_frame["B"][0] - 8.0).abs() < FLOAT_TOLERANCE);
|
||||
assert!((mapped_frame["B"][1] - 10.0).abs() < FLOAT_TOLERANCE);
|
||||
assert!((mapped_frame["B"][2] - 12.0).abs() < FLOAT_TOLERANCE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frame_zip() {
|
||||
let f1 = create_test_frame_f64(); // A=[1,2,3], B=[4,5,6]
|
||||
let f2 = create_test_frame_f64_alt(); // A=[0.1,0.2,0.3], B=[0.4,0.5,0.6]
|
||||
let zipped_frame = f1.frame_zip(&f2, |x, y| x + y); // Element-wise addition
|
||||
assert_eq!(zipped_frame.columns(), f1.columns());
|
||||
assert_eq!(zipped_frame.index(), f1.index());
|
||||
assert!((zipped_frame["A"][0] - 1.1).abs() < FLOAT_TOLERANCE);
|
||||
assert!((zipped_frame["A"][1] - 2.2).abs() < FLOAT_TOLERANCE);
|
||||
assert!((zipped_frame["A"][2] - 3.3).abs() < FLOAT_TOLERANCE);
|
||||
assert!((zipped_frame["B"][0] - 4.4).abs() < FLOAT_TOLERANCE);
|
||||
assert!((zipped_frame["B"][1] - 5.5).abs() < FLOAT_TOLERANCE);
|
||||
assert!((zipped_frame["B"][2] - 6.6).abs() < FLOAT_TOLERANCE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Frame::frame_zip: incompatible dimensions (self: 3x1, other: 3x2)")]
|
||||
fn test_frame_zip_panic() {
|
||||
let mut f1 = create_test_frame_f64();
|
||||
let f2 = create_test_frame_f64_alt();
|
||||
f1.delete_column("B");
|
||||
|
||||
f1.frame_zip(&f2, |x, y| x + y); // Should panic due to different column counts
|
||||
}
|
||||
|
||||
// --- Element-wise Arithmetic Ops Tests ---
|
||||
#[test]
|
||||
fn test_frame_arithmetic_ops_f64() {
|
||||
|
@ -20,6 +20,21 @@ impl SeriesOps for Frame<f64> {
|
||||
{
|
||||
self.matrix().apply_axis(axis, f)
|
||||
}
|
||||
|
||||
fn map<F>(&self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64) -> f64,
|
||||
{
|
||||
self.matrix().map(f)
|
||||
}
|
||||
|
||||
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64, f64) -> f64,
|
||||
{
|
||||
self.matrix().zip(other.matrix(), f)
|
||||
}
|
||||
|
||||
fn matrix_mul(&self, other: &Self) -> FloatMatrix {
|
||||
self.matrix().matrix_mul(other.matrix())
|
||||
}
|
||||
@ -113,7 +128,7 @@ mod tests {
|
||||
let col_names = vec!["A".to_string(), "B".to_string()];
|
||||
let frame = Frame::new(
|
||||
Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]),
|
||||
col_names,
|
||||
col_names.clone(),
|
||||
None,
|
||||
);
|
||||
assert_eq!(frame.sum_vertical(), frame.matrix().sum_vertical());
|
||||
@ -136,8 +151,10 @@ mod tests {
|
||||
assert_eq!(frame.is_nan(), frame.matrix().is_nan());
|
||||
assert_eq!(frame.apply_axis(Axis::Row, |x| x[0] + x[1]), vec![4.0, 6.0]);
|
||||
|
||||
|
||||
assert_eq!(frame.matrix_mul(&frame), frame.matrix().matrix_mul(&frame.matrix()));
|
||||
assert_eq!(
|
||||
frame.matrix_mul(&frame),
|
||||
frame.matrix().matrix_mul(&frame.matrix())
|
||||
);
|
||||
assert_eq!(frame.dot(&frame), frame.matrix().dot(&frame.matrix()));
|
||||
|
||||
// test transpose - returns a matrix.
|
||||
@ -146,6 +163,20 @@ mod tests {
|
||||
assert_eq!(frame_transposed_mat, frame_mat_transposed);
|
||||
assert_eq!(frame.matrix(), &frame.matrix().transpose().transpose());
|
||||
|
||||
// test map
|
||||
let mapped_frame = frame.map(|x| x * 2.0);
|
||||
let expected_matrix = frame.matrix().map(|x| x * 2.0);
|
||||
assert_eq!(mapped_frame, expected_matrix);
|
||||
|
||||
// test zip
|
||||
let other_frame = Frame::new(
|
||||
Matrix::from_cols(vec![vec![5.0, 6.0], vec![7.0, 8.0]]),
|
||||
col_names.clone(),
|
||||
None,
|
||||
);
|
||||
let zipped_frame = frame.zip(&other_frame, |x, y| x + y);
|
||||
let expected_zipped_matrix = frame.matrix().zip(other_frame.matrix(), |x, y| x + y);
|
||||
assert_eq!(zipped_frame, expected_zipped_matrix);
|
||||
}
|
||||
#[test]
|
||||
|
||||
@ -170,7 +201,4 @@ mod tests {
|
||||
vec![false, false]
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
@ -12,6 +12,14 @@ pub trait SeriesOps {
|
||||
where
|
||||
F: FnMut(&[f64]) -> U;
|
||||
|
||||
fn map<F>(&self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64) -> f64;
|
||||
|
||||
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64, f64) -> f64;
|
||||
|
||||
fn matrix_mul(&self, other: &Self) -> FloatMatrix;
|
||||
fn dot(&self, other: &Self) -> FloatMatrix;
|
||||
|
||||
@ -168,6 +176,36 @@ impl SeriesOps for FloatMatrix {
|
||||
fn dot(&self, other: &Self) -> FloatMatrix {
|
||||
self.matrix_mul(other)
|
||||
}
|
||||
|
||||
fn map<F>(&self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64) -> f64,
|
||||
{
|
||||
let data = self.data().iter().map(|&v| f(v)).collect::<Vec<_>>();
|
||||
FloatMatrix::from_vec(data, self.rows(), self.cols())
|
||||
}
|
||||
|
||||
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64, f64) -> f64,
|
||||
{
|
||||
assert!(
|
||||
self.rows() == other.rows() && self.cols() == other.cols(),
|
||||
"Matrix dimensions mismatch: left is {}x{}, right is {}x{}",
|
||||
self.rows(),
|
||||
self.cols(),
|
||||
other.rows(),
|
||||
other.cols()
|
||||
);
|
||||
|
||||
let data = self
|
||||
.data()
|
||||
.iter()
|
||||
.zip(other.data().iter())
|
||||
.map(|(&a, &b)| f(a, b))
|
||||
.collect();
|
||||
crate::matrix::Matrix::from_vec(data, self.rows(), self.cols())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -343,6 +381,48 @@ mod tests {
|
||||
a.dot(&b); // This should panic due to dimension mismatch
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_series_ops_map() {
|
||||
let matrix = create_float_test_matrix();
|
||||
// Map function to double each value
|
||||
let mapped_matrix = matrix.map(|x| x * 2.0);
|
||||
// Expected data after mapping
|
||||
let expected_data = vec![2.0, 4.0, 6.0, 8.0, f64::NAN, 12.0, 14.0, 16.0, f64::NAN];
|
||||
let expected_matrix = FloatMatrix::from_vec(expected_data, 3, 3);
|
||||
// assert_eq!(mapped_matrix, expected_matrix);
|
||||
for i in 0..mapped_matrix.data().len() {
|
||||
// if not nan, check equality
|
||||
if !mapped_matrix.data()[i].is_nan() {
|
||||
assert_eq!(mapped_matrix.data()[i], expected_matrix.data()[i]);
|
||||
} else {
|
||||
assert!(mapped_matrix.data()[i].is_nan());
|
||||
assert!(expected_matrix.data()[i].is_nan());
|
||||
}
|
||||
}
|
||||
assert_eq!(mapped_matrix.rows(), expected_matrix.rows());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_series_ops_zip() {
|
||||
let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix
|
||||
let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix
|
||||
// Zip function to add corresponding elements
|
||||
let zipped_matrix = a.zip(&b, |x, y| x + y);
|
||||
// Expected data after zipping
|
||||
let expected_data = vec![6.0, 8.0, 10.0, 12.0];
|
||||
let expected_matrix = FloatMatrix::from_vec(expected_data, 2, 2);
|
||||
assert_eq!(zipped_matrix, expected_matrix);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Matrix dimensions mismatch: left is 2x2, right is 3x2")]
|
||||
fn test_series_ops_zip_panic() {
|
||||
let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix
|
||||
let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 3, 2); // 3x2 matrix
|
||||
// This should panic due to dimension mismatch
|
||||
a.zip(&b, |x, y| x + y);
|
||||
}
|
||||
|
||||
// --- Edge Cases for SeriesOps ---
|
||||
|
||||
#[test]
|
||||
|
Loading…
x
Reference in New Issue
Block a user