mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Merge pull request #45 from Magnus167/frame_ops_enhance
Enhance element-wise and bitwise operations for Frame<T> and Frame<bool>
This commit is contained in:
commit
092aebd5fc
@ -781,14 +781,13 @@ impl<T: Clone + PartialEq> IndexMut<&str> for Frame<T> {
|
|||||||
/// Panics if column labels or row indices differ between operands.
|
/// Panics if column labels or row indices differ between operands.
|
||||||
macro_rules! impl_elementwise_frame_op {
|
macro_rules! impl_elementwise_frame_op {
|
||||||
($OpTrait:ident, $method:ident) => {
|
($OpTrait:ident, $method:ident) => {
|
||||||
|
// &Frame<T> $OpTrait &Frame<T>
|
||||||
impl<'a, 'b, T> std::ops::$OpTrait<&'b Frame<T>> for &'a Frame<T>
|
impl<'a, 'b, T> std::ops::$OpTrait<&'b Frame<T>> for &'a Frame<T>
|
||||||
where
|
where
|
||||||
T: Clone + PartialEq + std::ops::$OpTrait<Output = T>,
|
T: Clone + PartialEq + std::ops::$OpTrait<Output = T>,
|
||||||
{
|
{
|
||||||
type Output = Frame<T>;
|
type Output = Frame<T>;
|
||||||
|
|
||||||
fn $method(self, rhs: &'b Frame<T>) -> Frame<T> {
|
fn $method(self, rhs: &'b Frame<T>) -> Frame<T> {
|
||||||
// Verify matching schema
|
|
||||||
if self.column_names != rhs.column_names {
|
if self.column_names != rhs.column_names {
|
||||||
panic!(
|
panic!(
|
||||||
"Element-wise {}: column names do not match. Left: {:?}, Right: {:?}",
|
"Element-wise {}: column names do not match. Left: {:?}, Right: {:?}",
|
||||||
@ -805,21 +804,47 @@ macro_rules! impl_elementwise_frame_op {
|
|||||||
rhs.index
|
rhs.index
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply the matrix operation
|
|
||||||
let result_matrix = (&self.matrix).$method(&rhs.matrix);
|
let result_matrix = (&self.matrix).$method(&rhs.matrix);
|
||||||
|
|
||||||
// Determine index for the result
|
|
||||||
let new_index = match self.index {
|
let new_index = match self.index {
|
||||||
RowIndex::Range(_) => None,
|
RowIndex::Range(_) => None,
|
||||||
_ => Some(self.index.clone()),
|
_ => Some(self.index.clone()),
|
||||||
};
|
};
|
||||||
|
|
||||||
Frame::new(result_matrix, self.column_names.clone(), new_index)
|
Frame::new(result_matrix, self.column_names.clone(), new_index)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Frame<T> $OpTrait &Frame<T>
|
||||||
|
impl<'b, T> std::ops::$OpTrait<&'b Frame<T>> for Frame<T>
|
||||||
|
where
|
||||||
|
T: Clone + PartialEq + std::ops::$OpTrait<Output = T>,
|
||||||
|
{
|
||||||
|
type Output = Frame<T>;
|
||||||
|
fn $method(self, rhs: &'b Frame<T>) -> Frame<T> {
|
||||||
|
(&self).$method(rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// &Frame<T> $OpTrait Frame<T>
|
||||||
|
impl<'a, T> std::ops::$OpTrait<Frame<T>> for &'a Frame<T>
|
||||||
|
where
|
||||||
|
T: Clone + PartialEq + std::ops::$OpTrait<Output = T>,
|
||||||
|
{
|
||||||
|
type Output = Frame<T>;
|
||||||
|
fn $method(self, rhs: Frame<T>) -> Frame<T> {
|
||||||
|
self.$method(&rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Frame<T> $OpTrait Frame<T>
|
||||||
|
impl<T> std::ops::$OpTrait<Frame<T>> for Frame<T>
|
||||||
|
where
|
||||||
|
T: Clone + PartialEq + std::ops::$OpTrait<Output = T>,
|
||||||
|
{
|
||||||
|
type Output = Frame<T>;
|
||||||
|
fn $method(self, rhs: Frame<T>) -> Frame<T> {
|
||||||
|
(&self).$method(&rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_elementwise_frame_op!(Add, add);
|
impl_elementwise_frame_op!(Add, add);
|
||||||
impl_elementwise_frame_op!(Sub, sub);
|
impl_elementwise_frame_op!(Sub, sub);
|
||||||
impl_elementwise_frame_op!(Mul, mul);
|
impl_elementwise_frame_op!(Mul, mul);
|
||||||
@ -830,11 +855,10 @@ impl_elementwise_frame_op!(Div, div);
|
|||||||
/// Panics if column labels or row indices differ between operands.
|
/// Panics if column labels or row indices differ between operands.
|
||||||
macro_rules! impl_bitwise_frame_op {
|
macro_rules! impl_bitwise_frame_op {
|
||||||
($OpTrait:ident, $method:ident) => {
|
($OpTrait:ident, $method:ident) => {
|
||||||
|
// &Frame<bool> $OpTrait &Frame<bool>
|
||||||
impl<'a, 'b> std::ops::$OpTrait<&'b Frame<bool>> for &'a Frame<bool> {
|
impl<'a, 'b> std::ops::$OpTrait<&'b Frame<bool>> for &'a Frame<bool> {
|
||||||
type Output = Frame<bool>;
|
type Output = Frame<bool>;
|
||||||
|
|
||||||
fn $method(self, rhs: &'b Frame<bool>) -> Frame<bool> {
|
fn $method(self, rhs: &'b Frame<bool>) -> Frame<bool> {
|
||||||
// Verify matching schema
|
|
||||||
if self.column_names != rhs.column_names {
|
if self.column_names != rhs.column_names {
|
||||||
panic!(
|
panic!(
|
||||||
"Bitwise {}: column names do not match. Left: {:?}, Right: {:?}",
|
"Bitwise {}: column names do not match. Left: {:?}, Right: {:?}",
|
||||||
@ -851,25 +875,43 @@ macro_rules! impl_bitwise_frame_op {
|
|||||||
rhs.index
|
rhs.index
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply the matrix operation
|
|
||||||
let result_matrix = (&self.matrix).$method(&rhs.matrix);
|
let result_matrix = (&self.matrix).$method(&rhs.matrix);
|
||||||
|
|
||||||
// Determine index for the result
|
|
||||||
let new_index = match self.index {
|
let new_index = match self.index {
|
||||||
RowIndex::Range(_) => None,
|
RowIndex::Range(_) => None,
|
||||||
_ => Some(self.index.clone()),
|
_ => Some(self.index.clone()),
|
||||||
};
|
};
|
||||||
|
|
||||||
Frame::new(result_matrix, self.column_names.clone(), new_index)
|
Frame::new(result_matrix, self.column_names.clone(), new_index)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Frame<bool> $OpTrait &Frame<bool>
|
||||||
|
impl<'b> std::ops::$OpTrait<&'b Frame<bool>> for Frame<bool> {
|
||||||
|
type Output = Frame<bool>;
|
||||||
|
fn $method(self, rhs: &'b Frame<bool>) -> Frame<bool> {
|
||||||
|
(&self).$method(rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// &Frame<bool> $OpTrait Frame<bool>
|
||||||
|
impl<'a> std::ops::$OpTrait<Frame<bool>> for &'a Frame<bool> {
|
||||||
|
type Output = Frame<bool>;
|
||||||
|
fn $method(self, rhs: Frame<bool>) -> Frame<bool> {
|
||||||
|
self.$method(&rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Frame<bool> $OpTrait Frame<bool>
|
||||||
|
impl std::ops::$OpTrait<Frame<bool>> for Frame<bool> {
|
||||||
|
type Output = Frame<bool>;
|
||||||
|
fn $method(self, rhs: Frame<bool>) -> Frame<bool> {
|
||||||
|
(&self).$method(&rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_bitwise_frame_op!(BitAnd, bitand);
|
impl_bitwise_frame_op!(BitAnd, bitand);
|
||||||
impl_bitwise_frame_op!(BitOr, bitor);
|
impl_bitwise_frame_op!(BitOr, bitor);
|
||||||
impl_bitwise_frame_op!(BitXor, bitxor);
|
impl_bitwise_frame_op!(BitXor, bitxor);
|
||||||
|
|
||||||
|
/* ---------- Logical NOT ---------- */
|
||||||
/// Implements logical NOT (`!`) for `Frame<bool>`, consuming the frame.
|
/// Implements logical NOT (`!`) for `Frame<bool>`, consuming the frame.
|
||||||
impl Not for Frame<bool> {
|
impl Not for Frame<bool> {
|
||||||
type Output = Frame<bool>;
|
type Output = Frame<bool>;
|
||||||
@ -888,12 +930,30 @@ impl Not for Frame<bool> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Tests ---
|
/// Implements logical NOT (`!`) for `&Frame<bool>`, borrowing the frame.
|
||||||
|
impl Not for &Frame<bool> {
|
||||||
|
type Output = Frame<bool>;
|
||||||
|
|
||||||
|
fn not(self) -> Frame<bool> {
|
||||||
|
// Apply NOT to the underlying matrix
|
||||||
|
let result_matrix = !&self.matrix;
|
||||||
|
|
||||||
|
// Determine index for the result
|
||||||
|
let new_index = match self.index {
|
||||||
|
RowIndex::Range(_) => None,
|
||||||
|
_ => Some(self.index.clone()),
|
||||||
|
};
|
||||||
|
|
||||||
|
Frame::new(result_matrix, self.column_names.clone(), new_index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Tests ---
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
// Assume Matrix is available from crate::matrix or similar
|
// Assume Matrix is available from crate::matrix or similar
|
||||||
use crate::matrix::Matrix;
|
use crate::matrix::{BoolOps, Matrix};
|
||||||
use chrono::NaiveDate;
|
use chrono::NaiveDate;
|
||||||
// HashMap needed for direct inspection in tests if required
|
// HashMap needed for direct inspection in tests if required
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -1295,7 +1355,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn frame_row_mutate_default_index() {
|
fn frame_row_mutate_default_index() {
|
||||||
let mut frame = create_test_frame_f64(); // Index 0..3, A=[1,2,3], B=[4,5,6]
|
let mut frame = create_test_frame_f64(); // Index 0..3, A=[1,2,3], B=[4,5,6]
|
||||||
// Mutate using set("col_name", value)
|
// Mutate using set("col_name", value)
|
||||||
frame.get_row_mut(1).set("A", 2.9); // Mutate row index 1, col A
|
frame.get_row_mut(1).set("A", 2.9); // Mutate row index 1, col A
|
||||||
assert_eq!(frame["A"], vec![1.0, 2.9, 3.0]);
|
assert_eq!(frame["A"], vec![1.0, 2.9, 3.0]);
|
||||||
// Mutate using IndexMut by physical column index
|
// Mutate using IndexMut by physical column index
|
||||||
@ -1349,7 +1409,7 @@ mod tests {
|
|||||||
fn test_row_view_name_panic() {
|
fn test_row_view_name_panic() {
|
||||||
let frame = create_test_frame_f64();
|
let frame = create_test_frame_f64();
|
||||||
let row_view = frame.get_row(0);
|
let row_view = frame.get_row(0);
|
||||||
let _ = row_view["C"]; // Access non-existent column name
|
let _ = row_view["C"]; // Access non-existent column Z
|
||||||
}
|
}
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "column index 3 out of bounds")] // Check specific message
|
#[should_panic(expected = "column index 3 out of bounds")] // Check specific message
|
||||||
@ -1438,7 +1498,7 @@ mod tests {
|
|||||||
assert_eq!(frame.cols(), 2);
|
assert_eq!(frame.cols(), 2);
|
||||||
assert_eq!(frame["X"], vec![1, 2]); // X data unchanged
|
assert_eq!(frame["X"], vec![1, 2]); // X data unchanged
|
||||||
assert_eq!(frame["B"], vec![5, 6]); // B data unchanged
|
assert_eq!(frame["B"], vec![5, 6]); // B data unchanged
|
||||||
// Check internal state after delete + rebuild_col_lookup
|
// Check internal state after delete + rebuild_col_lookup
|
||||||
assert_eq!(frame.column_index("X"), Some(0)); // X is now physical col 0
|
assert_eq!(frame.column_index("X"), Some(0)); // X is now physical col 0
|
||||||
assert_eq!(frame.column_index("B"), Some(1)); // B is now physical col 1
|
assert_eq!(frame.column_index("B"), Some(1)); // B is now physical col 1
|
||||||
assert!(frame.column_index("A").is_none());
|
assert!(frame.column_index("A").is_none());
|
||||||
@ -1447,10 +1507,10 @@ mod tests {
|
|||||||
// Sort Columns [X, B] -> [B, X]
|
// Sort Columns [X, B] -> [B, X]
|
||||||
frame.sort_columns();
|
frame.sort_columns();
|
||||||
assert_eq!(frame.columns(), &["B", "X"]); // Alphabetical order of names
|
assert_eq!(frame.columns(), &["B", "X"]); // Alphabetical order of names
|
||||||
// Verify data remained with the correct logical column after sort
|
// Verify data remained with the correct logical column after sort
|
||||||
assert_eq!(frame["B"], vec![5, 6], "Data in B after sort"); // B should still have [5, 6]
|
assert_eq!(frame["B"], vec![5, 6], "Data in B after sort"); // B should still have [5, 6]
|
||||||
assert_eq!(frame["X"], vec![1, 2], "Data in X after sort"); // X should still have [1, 2]
|
assert_eq!(frame["X"], vec![1, 2], "Data in X after sort"); // X should still have [1, 2]
|
||||||
// Verify internal lookup map is correct after sort
|
// Verify internal lookup map is correct after sort
|
||||||
assert_eq!(frame.column_index("B"), Some(0), "Index of B after sort"); // B is now physical col 0
|
assert_eq!(frame.column_index("B"), Some(0), "Index of B after sort"); // B is now physical col 0
|
||||||
assert_eq!(frame.column_index("X"), Some(1), "Index of X after sort"); // X is now physical col 1
|
assert_eq!(frame.column_index("X"), Some(1), "Index of X after sort"); // X is now physical col 1
|
||||||
assert_eq!(frame.col_lookup.len(), 2);
|
assert_eq!(frame.col_lookup.len(), 2);
|
||||||
@ -1666,6 +1726,79 @@ mod tests {
|
|||||||
assert_eq!(frame_div["Y"], vec![10, -10]);
|
assert_eq!(frame_div["Y"], vec![10, -10]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tests_for_frame_arithmetic_ops() {
|
||||||
|
let ops: Vec<(
|
||||||
|
&str,
|
||||||
|
fn(&Frame<f64>, &Frame<f64>) -> Frame<f64>,
|
||||||
|
fn(&Frame<f64>, &Frame<f64>) -> Frame<f64>,
|
||||||
|
)> = vec![
|
||||||
|
("addition", |a, b| a + b, |a, b| (&*a) + (&*b)),
|
||||||
|
("subtraction", |a, b| a - b, |a, b| (&*a) - (&*b)),
|
||||||
|
("multiplication", |a, b| a * b, |a, b| (&*a) * (&*b)),
|
||||||
|
("division", |a, b| a / b, |a, b| (&*a) / (&*b)),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (op_name, owned_op, ref_op) in ops {
|
||||||
|
let f1 = create_test_frame_f64();
|
||||||
|
let f2 = create_test_frame_f64_alt();
|
||||||
|
let result_owned = owned_op(&f1, &f2);
|
||||||
|
let expected = ref_op(&f1, &f2);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
result_owned.columns(),
|
||||||
|
f1.columns(),
|
||||||
|
"Column mismatch for {}",
|
||||||
|
op_name
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
result_owned.index(),
|
||||||
|
f1.index(),
|
||||||
|
"Index mismatch for {}",
|
||||||
|
op_name
|
||||||
|
);
|
||||||
|
|
||||||
|
let bool_mat = result_owned.matrix().eq_elem(expected.matrix().clone());
|
||||||
|
assert!(bool_mat.all(), "Element-wise {} failed", op_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// test not , and or on frame
|
||||||
|
#[test]
|
||||||
|
fn tests_for_frame_bool_ops() {
|
||||||
|
let ops: Vec<(
|
||||||
|
&str,
|
||||||
|
fn(&Frame<bool>, &Frame<bool>) -> Frame<bool>,
|
||||||
|
fn(&Frame<bool>, &Frame<bool>) -> Frame<bool>,
|
||||||
|
)> = vec![
|
||||||
|
("and", |a, b| a & b, |a, b| (&*a) & (&*b)),
|
||||||
|
("or", |a, b| a | b, |a, b| (&*a) | (&*b)),
|
||||||
|
("xor", |a, b| a ^ b, |a, b| (&*a) ^ (&*b)),
|
||||||
|
];
|
||||||
|
for (op_name, owned_op, ref_op) in ops {
|
||||||
|
let f1 = create_test_frame_bool();
|
||||||
|
let f2 = create_test_frame_bool_alt();
|
||||||
|
let result_owned = owned_op(&f1, &f2);
|
||||||
|
let expected = ref_op(&f1, &f2);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
result_owned.columns(),
|
||||||
|
f1.columns(),
|
||||||
|
"Column mismatch for {}",
|
||||||
|
op_name
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
result_owned.index(),
|
||||||
|
f1.index(),
|
||||||
|
"Index mismatch for {}",
|
||||||
|
op_name
|
||||||
|
);
|
||||||
|
|
||||||
|
let bool_mat = result_owned.matrix().eq_elem(expected.matrix().clone());
|
||||||
|
assert!(bool_mat.all(), "Element-wise {} failed", op_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_frame_arithmetic_ops_date_index() {
|
fn test_frame_arithmetic_ops_date_index() {
|
||||||
let dates = vec![d(2024, 1, 1), d(2024, 1, 2)];
|
let dates = vec![d(2024, 1, 1), d(2024, 1, 2)];
|
||||||
|
@ -40,7 +40,7 @@ pub struct BDatesList {
|
|||||||
/// ```rust
|
/// ```rust
|
||||||
/// use chrono::NaiveDate;
|
/// use chrono::NaiveDate;
|
||||||
/// use std::error::Error;
|
/// use std::error::Error;
|
||||||
/// use rustframe::utils::{BDatesList, DateFreq}; // Replace bdates with your actual crate/module name
|
/// use rustframe::utils::{BDatesList, DateFreq};
|
||||||
///
|
///
|
||||||
/// fn main() -> Result<(), Box<dyn Error>> {
|
/// fn main() -> Result<(), Box<dyn Error>> {
|
||||||
/// let start_date = "2023-11-01".to_string(); // Wednesday
|
/// let start_date = "2023-11-01".to_string(); // Wednesday
|
||||||
@ -68,7 +68,7 @@ pub struct BDatesList {
|
|||||||
/// ```rust
|
/// ```rust
|
||||||
/// use chrono::NaiveDate;
|
/// use chrono::NaiveDate;
|
||||||
/// use std::error::Error;
|
/// use std::error::Error;
|
||||||
/// use rustframe::utils::{BDatesList, DateFreq}; // Replace bdates with your actual crate/module name
|
/// use rustframe::utils::{BDatesList, DateFreq};
|
||||||
///
|
///
|
||||||
/// fn main() -> Result<(), Box<dyn Error>> {
|
/// fn main() -> Result<(), Box<dyn Error>> {
|
||||||
/// let start_date = "2024-02-28".to_string(); // Wednesday
|
/// let start_date = "2024-02-28".to_string(); // Wednesday
|
||||||
@ -98,7 +98,7 @@ pub struct BDatesList {
|
|||||||
/// ```rust
|
/// ```rust
|
||||||
/// use chrono::NaiveDate;
|
/// use chrono::NaiveDate;
|
||||||
/// use std::error::Error;
|
/// use std::error::Error;
|
||||||
/// use rustframe::utils::{BDatesList, DateFreq}; // Replace bdates with your actual crate/module name
|
/// use rustframe::utils::{BDatesList, DateFreq};
|
||||||
///
|
///
|
||||||
/// fn main() -> Result<(), Box<dyn Error>> {
|
/// fn main() -> Result<(), Box<dyn Error>> {
|
||||||
/// let start_date = "2023-11-20".to_string(); // Mon, Week 47
|
/// let start_date = "2023-11-20".to_string(); // Mon, Week 47
|
||||||
@ -294,7 +294,7 @@ impl BDatesList {
|
|||||||
/// ```rust
|
/// ```rust
|
||||||
/// use chrono::NaiveDate;
|
/// use chrono::NaiveDate;
|
||||||
/// use std::error::Error;
|
/// use std::error::Error;
|
||||||
/// use rustframe::utils::{BDatesGenerator, DateFreq}; // Replace bdates with your actual crate/module name
|
/// use rustframe::utils::{BDatesGenerator, DateFreq};
|
||||||
///
|
///
|
||||||
/// fn main() -> Result<(), Box<dyn Error>> {
|
/// fn main() -> Result<(), Box<dyn Error>> {
|
||||||
/// let start = NaiveDate::from_ymd_opt(2024, 4, 29).unwrap(); // Monday
|
/// let start = NaiveDate::from_ymd_opt(2024, 4, 29).unwrap(); // Monday
|
||||||
|
@ -160,7 +160,7 @@ enum GroupKey {
|
|||||||
/// ```rust
|
/// ```rust
|
||||||
/// use chrono::NaiveDate;
|
/// use chrono::NaiveDate;
|
||||||
/// use std::error::Error;
|
/// use std::error::Error;
|
||||||
/// # use rustframe::utils::{DatesList, DateFreq}; // Assuming the crate/module is named 'dates'
|
/// use rustframe::utils::{DatesList, DateFreq};
|
||||||
///
|
///
|
||||||
/// # fn main() -> Result<(), Box<dyn Error>> {
|
/// # fn main() -> Result<(), Box<dyn Error>> {
|
||||||
/// let start_date = "2023-11-01".to_string(); // Wednesday
|
/// let start_date = "2023-11-01".to_string(); // Wednesday
|
||||||
|
Loading…
x
Reference in New Issue
Block a user