Merge pull request #55 from Magnus167/map-zip

Implement `map` and `zip` methods for SeriesOps and Frame
This commit is contained in:
Palash Tyagi 2025-07-06 10:56:11 +01:00 committed by GitHub
commit 5db5475a61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 193 additions and 8 deletions

View File

@ -471,7 +471,6 @@ impl<T: Clone + PartialEq> Frame<T> {
deleted_data
}
/// Returns a new `Matrix` that is the transpose of the current frame's matrix.
pub fn transpose(&self) -> Matrix<T> {
self.matrix.transpose()
@ -513,6 +512,45 @@ impl<T: Clone + PartialEq> Frame<T> {
}
}
/// Builds a new `Frame` by applying `f` to every element of this frame.
///
/// The closure receives a reference to each element of the underlying matrix
/// data, in storage order, and returns the replacement value. The resulting
/// frame has the same number of rows and columns as `self`, and is constructed
/// with clones of this frame's column names and index.
///
/// `f` may be any `FnMut`, so closures that keep running state (counters,
/// accumulators) are accepted as well as plain `Fn` closures.
pub fn frame_map(&self, f: impl FnMut(&T) -> T) -> Frame<T> {
    let mapped: Vec<T> = self.matrix.data().iter().map(f).collect();
    Frame::new(
        Matrix::from_vec(mapped, self.matrix.rows(), self.matrix.cols()),
        self.column_names.clone(),
        Some(self.index.clone()),
    )
}
/// Combines two frames element by element into a new `Frame`.
///
/// `f` is called with a reference to the element of `self` and the element of
/// `other` at the same position in the underlying matrix data, and returns the
/// value stored at that position in the result.
///
/// Only the dimensions of the two frames are compared. The column names and
/// index of `other` are not inspected; the result is constructed with clones
/// of the column names and index of `self`.
///
/// `f` may be any `FnMut`, so stateful closures are accepted as well as plain
/// `Fn` closures.
///
/// # Panics
///
/// Panics if the two frames do not have the same number of rows and columns.
pub fn frame_zip(&self, other: &Frame<T>, mut f: impl FnMut(&T, &T) -> T) -> Frame<T> {
    if self.rows() != other.rows() || self.cols() != other.cols() {
        panic!(
            "Frame::frame_zip: incompatible dimensions (self: {}x{}, other: {}x{})",
            self.rows(),
            self.cols(),
            other.rows(),
            other.cols()
        );
    }

    let combined: Vec<T> = self
        .matrix
        .data()
        .iter()
        .zip(other.matrix.data())
        .map(|(lhs, rhs)| f(lhs, rhs))
        .collect();

    Frame::new(
        Matrix::from_vec(combined, self.rows(), self.cols()),
        self.column_names.clone(),
        Some(self.index.clone()),
    )
}
// Internal helpers
/// Rebuilds the column lookup map to match the current `column_names` ordering.
@ -1654,6 +1692,45 @@ mod tests {
assert_eq!(frame1.columns(), &["Z"]);
}
#[test]
fn test_frame_map() {
    // Fixture columns: A = [1, 2, 3], B = [4, 5, 6].
    let source = create_test_frame_f64();
    // Double every value.
    let doubled = source.frame_map(|value| value * 2.0);

    // Labels must be carried over unchanged.
    assert_eq!(doubled.columns(), source.columns());
    assert_eq!(doubled.index(), source.index());

    // Expected values per column after doubling.
    let expected: [(&str, [f64; 3]); 2] = [("A", [2.0, 4.0, 6.0]), ("B", [8.0, 10.0, 12.0])];
    for (name, values) in expected.iter() {
        for (row, want) in values.iter().enumerate() {
            assert!((doubled[*name][row] - *want).abs() < FLOAT_TOLERANCE);
        }
    }
}
#[test]
fn test_frame_zip() {
    // Fixture columns: A = [1, 2, 3], B = [4, 5, 6].
    let left = create_test_frame_f64();
    // Fixture columns: A = [0.1, 0.2, 0.3], B = [0.4, 0.5, 0.6].
    let right = create_test_frame_f64_alt();
    // Element-wise addition of the two frames.
    let summed = left.frame_zip(&right, |x, y| x + y);

    // Labels come from the left-hand frame.
    assert_eq!(summed.columns(), left.columns());
    assert_eq!(summed.index(), left.index());

    // Expected values per column after addition.
    let expected: [(&str, [f64; 3]); 2] = [("A", [1.1, 2.2, 3.3]), ("B", [4.4, 5.5, 6.6])];
    for (name, values) in expected.iter() {
        for (row, want) in values.iter().enumerate() {
            assert!((summed[*name][row] - *want).abs() < FLOAT_TOLERANCE);
        }
    }
}
#[test]
#[should_panic(expected = "Frame::frame_zip: incompatible dimensions (self: 3x1, other: 3x2)")]
fn test_frame_zip_panic() {
    let mut narrow = create_test_frame_f64();
    let wide = create_test_frame_f64_alt();
    // Drop one column so the two frames no longer have the same width.
    narrow.delete_column("B");
    // Zipping frames with different column counts must panic.
    narrow.frame_zip(&wide, |lhs, rhs| lhs + rhs);
}
// --- Element-wise Arithmetic Ops Tests ---
#[test]
fn test_frame_arithmetic_ops_f64() {

View File

@ -20,6 +20,21 @@ impl SeriesOps for Frame<f64> {
{
self.matrix().apply_axis(axis, f)
}
/// Applies `f` through the `map` of the frame's underlying matrix.
///
/// This is a pure delegation to `self.matrix().map(f)`. The return type is a
/// bare `FloatMatrix`, not a `Frame`, so the frame's column names and index
/// are not part of the result.
fn map<F>(&self, f: F) -> FloatMatrix
where
F: Fn(f64) -> f64,
{
self.matrix().map(f)
}
/// Combines this frame's matrix with `other`'s matrix using `f`.
///
/// This is a pure delegation to `zip` on the underlying matrices; any
/// dimension checking is done by that call, not here. The return type is a
/// bare `FloatMatrix`, so neither frame's column names or index are part of
/// the result.
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
where
F: Fn(f64, f64) -> f64,
{
self.matrix().zip(other.matrix(), f)
}
/// Delegates to `matrix_mul` on the two frames' underlying matrices and
/// returns the resulting `FloatMatrix`.
fn matrix_mul(&self, other: &Self) -> FloatMatrix {
self.matrix().matrix_mul(other.matrix())
}
@ -113,7 +128,7 @@ mod tests {
let col_names = vec!["A".to_string(), "B".to_string()];
let frame = Frame::new(
Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]),
col_names,
col_names.clone(),
None,
);
assert_eq!(frame.sum_vertical(), frame.matrix().sum_vertical());
@ -136,8 +151,10 @@ mod tests {
assert_eq!(frame.is_nan(), frame.matrix().is_nan());
assert_eq!(frame.apply_axis(Axis::Row, |x| x[0] + x[1]), vec![4.0, 6.0]);
assert_eq!(frame.matrix_mul(&frame), frame.matrix().matrix_mul(&frame.matrix()));
assert_eq!(
frame.matrix_mul(&frame),
frame.matrix().matrix_mul(&frame.matrix())
);
assert_eq!(frame.dot(&frame), frame.matrix().dot(&frame.matrix()));
// test transpose - returns a matrix.
@ -146,6 +163,20 @@ mod tests {
assert_eq!(frame_transposed_mat, frame_mat_transposed);
assert_eq!(frame.matrix(), &frame.matrix().transpose().transpose());
// test map
let mapped_frame = frame.map(|x| x * 2.0);
let expected_matrix = frame.matrix().map(|x| x * 2.0);
assert_eq!(mapped_frame, expected_matrix);
// test zip
let other_frame = Frame::new(
Matrix::from_cols(vec![vec![5.0, 6.0], vec![7.0, 8.0]]),
col_names.clone(),
None,
);
let zipped_frame = frame.zip(&other_frame, |x, y| x + y);
let expected_zipped_matrix = frame.matrix().zip(other_frame.matrix(), |x, y| x + y);
assert_eq!(zipped_frame, expected_zipped_matrix);
}
#[test]
@ -170,7 +201,4 @@ mod tests {
vec![false, false]
);
}
}

View File

@ -12,6 +12,14 @@ pub trait SeriesOps {
where
F: FnMut(&[f64]) -> U;
/// Produces a `FloatMatrix` by passing values of `self` through `f`.
///
/// `f` takes one `f64` and returns the `f64` to store in the result.
fn map<F>(&self, f: F) -> FloatMatrix
where
F: Fn(f64) -> f64;
/// Produces a `FloatMatrix` by combining values of `self` and `other` with `f`.
///
/// `f` takes one `f64` from each operand and returns the `f64` to store in
/// the result.
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
where
F: Fn(f64, f64) -> f64;
/// Multiplies `self` by `other`, returning a `FloatMatrix`.
// NOTE(review): presumably a standard matrix product; the exact semantics and
// dimension requirements are defined by the implementations — confirm there.
fn matrix_mul(&self, other: &Self) -> FloatMatrix;
/// Dot product of `self` and `other`, returning a `FloatMatrix`.
// NOTE(review): the `FloatMatrix` implementation forwards this to
// `matrix_mul`; other implementors are presumably expected to match — confirm.
fn dot(&self, other: &Self) -> FloatMatrix;
@ -168,6 +176,36 @@ impl SeriesOps for FloatMatrix {
/// Alias for `matrix_mul`: forwards `other` unchanged and returns its result.
fn dot(&self, other: &Self) -> FloatMatrix {
self.matrix_mul(other)
}
/// Returns a new matrix of the same shape whose elements are `f` applied to
/// each element of `self`, in storage order.
fn map<F>(&self, f: F) -> FloatMatrix
where
    F: Fn(f64) -> f64,
{
    let transformed: Vec<f64> = self.data().iter().copied().map(f).collect();
    FloatMatrix::from_vec(transformed, self.rows(), self.cols())
}
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
where
F: Fn(f64, f64) -> f64,
{
assert!(
self.rows() == other.rows() && self.cols() == other.cols(),
"Matrix dimensions mismatch: left is {}x{}, right is {}x{}",
self.rows(),
self.cols(),
other.rows(),
other.cols()
);
let data = self
.data()
.iter()
.zip(other.data().iter())
.map(|(&a, &b)| f(a, b))
.collect();
crate::matrix::Matrix::from_vec(data, self.rows(), self.cols())
}
}
#[cfg(test)]
@ -343,6 +381,48 @@ mod tests {
a.dot(&b); // This should panic due to dimension mismatch
}
#[test]
fn test_series_ops_map() {
    let matrix = create_float_test_matrix();

    // Map function to double each value
    let mapped_matrix = matrix.map(|x| x * 2.0);

    // Expected data after mapping
    let expected_data = vec![2.0, 4.0, 6.0, 8.0, f64::NAN, 12.0, 14.0, 16.0, f64::NAN];
    let expected_matrix = FloatMatrix::from_vec(expected_data, 3, 3);

    // The shape must be preserved in both dimensions.
    assert_eq!(mapped_matrix.rows(), expected_matrix.rows());
    assert_eq!(mapped_matrix.cols(), expected_matrix.cols());
    assert_eq!(mapped_matrix.data().len(), expected_matrix.data().len());

    // The data contains NaN, and NaN never compares equal to itself, so the
    // elements are compared one by one with NaN positions checked via `is_nan`.
    for (idx, (got, want)) in mapped_matrix
        .data()
        .iter()
        .zip(expected_matrix.data().iter())
        .enumerate()
    {
        if want.is_nan() {
            assert!(got.is_nan(), "element {}: expected NaN, got {}", idx, got);
        } else {
            assert_eq!(got, want, "element {} differs", idx);
        }
    }
}
#[test]
fn test_series_ops_zip() {
    // Two 2x2 operands.
    let lhs = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
    let rhs = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2);

    // Add corresponding elements.
    let summed = lhs.zip(&rhs, |x, y| x + y);

    // Element-wise sums, same 2x2 shape.
    let expected = FloatMatrix::from_vec(vec![6.0, 8.0, 10.0, 12.0], 2, 2);
    assert_eq!(summed, expected);
}
#[test]
#[should_panic(expected = "Matrix dimensions mismatch: left is 2x2, right is 3x2")]
fn test_series_ops_zip_panic() {
    // 2x2 operand.
    let small = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
    // 3x2 operand: the shape differs, so zipping must be rejected.
    let large = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 3, 2);

    small.zip(&large, |lhs, rhs| lhs + rhs);
}
// --- Edge Cases for SeriesOps ---
#[test]