Merge pull request #19 from Magnus167/matr_eq

Implement element-wise equality comparison for Matrix and Frame
This commit is contained in:
Palash Tyagi 2025-04-26 02:56:07 +01:00 committed by GitHub
commit 9a7aae1e13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 189 additions and 30 deletions

View File

@ -30,7 +30,7 @@ jobs:
override: true
- name: Build documentation
run: cargo doc --no-deps
run: cargo doc --no-deps --release
- name: Prepare documentation for Pages
run: |

View File

@ -21,9 +21,9 @@ jobs:
- name: Generate code coverage
run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info
- name: Run doc-tests
run: cargo test --doc --all-features --workspace
run: cargo test --doc --all-features --workspace --release
- name: Test docs generation
run: cargo doc --no-deps
run: cargo doc --no-deps --release
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:

View File

@ -1,17 +1,17 @@
coverage:
status:
# Project coverage status: require ≥ 90% overall, allow up to 2% drop
# Project coverage status: require >= 90% overall, allow up to 2% drop
project:
default:
target: 90% # overall coverage goal
threshold: 2% # tolerable drop below target before failing
base: auto # compare against default branch
target: 95% # overall coverage goal
threshold: 2% # tolerable drop below target before failing
base: auto # compare against default branch
only_pulls: true # enforce on PRs only
# Patch coverage status: require ≥ 95% on new/changed lines, zero tolerance
# Patch coverage status: require >= 95% on new/changed lines, zero tolerance
patch:
default:
target: 95% # coverage goal for new or changed code
threshold: 0% # no uncovered lines allowed
base: auto # diff against default branch
target: 50% # coverage goal for new or changed code
threshold: 0% # no uncovered lines allowed
base: auto # diff against default branch
only_pulls: true

View File

@ -8,7 +8,7 @@ pub struct Matrix<T> {
data: Vec<T>,
}
impl<T> Matrix<T> {
impl<T: Clone> Matrix<T> {
/// Build from columns (each inner Vec is one column)
pub fn from_cols(cols_data: Vec<Vec<T>>) -> Self {
let cols = cols_data.len();
@ -38,13 +38,22 @@ impl<T> Matrix<T> {
Matrix { rows, cols, data }
}
pub fn rows(&self) -> usize {
self.rows
}
pub fn data(&self) -> &[T] {
&self.data
}
pub fn data_mut(&mut self) -> &mut [T] {
&mut self.data
}
pub fn as_vec(&self) -> Vec<T> {
self.data.clone()
}
pub fn rows(&self) -> usize {
self.rows
}
pub fn cols(&self) -> usize {
self.cols
}
@ -184,6 +193,77 @@ impl<'a, T> MatrixRow<'a, T> {
}
}
/// Specifies the axis along which to perform a reduction operation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Axis {
/// Apply reduction along columns (vertical axis).
Col,
/// Apply reduction along rows (horizontal axis).
Row,
}
/// A trait to turn either a `Matrix<T>` or a scalar T into a `Vec<T>` of
/// length `rows*cols` (broadcasting the scalar).
pub trait Broadcastable<T> {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>;
}
impl<T: Clone> Broadcastable<T> for T {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
vec![self.clone(); rows * cols]
}
}
impl<T: Clone> Broadcastable<T> for Matrix<T> {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
assert_eq!(self.rows, rows, "row count mismatch");
assert_eq!(self.cols, cols, "col count mismatch");
self.data.clone()
}
}
/// Generates element-wise eq, lt, le, gt and ge methods
/// where the rhs can be a `Matrix<T>` or a scalar T.
macro_rules! impl_elementwise_cmp {
(
$( $method:ident => $op:tt ),* $(,)?
) => {
impl<T: PartialOrd + Clone> Matrix<T> {
$(
#[doc = concat!("Element-wise comparison `self ", stringify!($op), " rhs`,\n\
where `rhs` may be a `Matrix<T>` or a scalar T.")]
pub fn $method<Rhs>(&self, rhs: Rhs) -> BoolMatrix
where
Rhs: Broadcastable<T>,
{
// Prepare broadcasted rhs-data
let rhs_data = rhs.to_vec(self.rows, self.cols);
// Pairwise compare
let data = self
.data
.iter()
.cloned()
.zip(rhs_data.into_iter())
.map(|(a, b)| a $op b)
.collect();
BoolMatrix::from_vec(data, self.rows, self.cols)
}
)*
}
};
}
// Instantiate element-wise comparison implementations for matrices.
impl_elementwise_cmp! {
eq_elementwise => ==,
lt_elementwise => <,
le_elementwise => <=,
gt_elementwise => >,
ge_elementwise => >=,
}
/// Generates element-wise arithmetic implementations for matrices.
macro_rules! impl_elementwise_op {
($OpTrait:ident, $method:ident, $op:tt) => {
@ -217,10 +297,29 @@ impl_elementwise_op!(Sub, sub, -);
impl_elementwise_op!(Mul, mul, *);
impl_elementwise_op!(Div, div, /);
pub type FloatMatrix = Matrix<f64>;
pub type BoolMatrix = Matrix<bool>;
pub type IntMatrix = Matrix<i32>;
pub type StringMatrix = Matrix<String>;
/// Generates element-wise arithmetic implementations for matrices with scalars.
macro_rules! impl_elementwise_op_scalar {
($OpTrait:ident, $method:ident, $op:tt) => {
impl<'a, T> std::ops::$OpTrait<T> for &'a Matrix<T>
where
T: Clone + std::ops::$OpTrait<Output = T>,
{
type Output = Matrix<T>;
fn $method(self, rhs: T) -> Matrix<T> {
// Apply the operation element-wise and collect into a new matrix
let data = self.data.iter().cloned().map(|a| a $op rhs.clone()).collect();
Matrix { rows: self.rows, cols: self.cols, data }
}
}
};
}
// Instantiate element-wise addition, subtraction, multiplication, and division
impl_elementwise_op_scalar!(Add, add, +);
impl_elementwise_op_scalar!(Sub, sub, -);
impl_elementwise_op_scalar!(Mul, mul, *);
impl_elementwise_op_scalar!(Div, div, /);
/// Generates element-wise bitwise operations for boolean matrices.
macro_rules! impl_bitwise_op {
@ -265,17 +364,15 @@ impl Not for Matrix<bool> {
}
}
/// Specifies the axis along which to perform a reduction operation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Axis {
/// Apply reduction along columns (vertical axis).
Col,
/// Apply reduction along rows (horizontal axis).
Row,
}
pub type FloatMatrix = Matrix<f64>;
pub type BoolMatrix = Matrix<bool>;
pub type IntMatrix = Matrix<i32>;
pub type StringMatrix = Matrix<String>;
#[cfg(test)]
mod tests {
use crate::matrix::BoolOps;
use super::{BoolMatrix, FloatMatrix, Matrix, StringMatrix};
// Helper function to create a basic Matrix for testing
@ -541,6 +638,22 @@ mod tests {
assert_eq!(rows[1], vec![2, 4, 6, 8]);
}
// test data_mut
#[test]
fn test_data_mut() {
let mut matrix = create_test_matrix(); // 3x3
// 1 4 7
// 2 5 8
// 3 6 9
let data_mut = matrix.data_mut();
data_mut[0] = 10;
data_mut[1] = 20;
assert_eq!(matrix[(0, 0)], 10);
assert_eq!(matrix[(1, 0)], 20);
}
#[test]
fn test_matrix_row_get_and_iter() {
let matrix = create_test_matrix_2x4(); // 2x4
@ -1162,7 +1275,6 @@ mod tests {
assert!((div[(1, 1)] - 4.0 / 3.5).abs() < 1e-9); // 1.14...
}
fn create_test_matrix_i32() -> Matrix<i32> {
Matrix::from_cols(vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]])
}
@ -1232,5 +1344,52 @@ mod tests {
assert_eq!(matrix2.column(1), initial_col0_data_m2.as_slice());
assert_eq!(matrix2.data(), &[4, 5, 6, 1, 2, 3, 7, 8, 9]);
}
// Axis enum doesn't have logic, no tests needed directly, but its presence is verified by compilation.
// Test broadcastable operations
#[test]
fn test_comparision_broadcast() {
let matrix = create_test_matrix();
// test all > 0
let result = matrix.gt_elementwise(0).as_vec();
let expected = vec![true; result.len()];
assert_eq!(result, expected);
let ma = create_test_matrix();
let mb = create_test_matrix();
let result = ma.eq_elementwise(mb);
assert!(result.all());
let result = matrix.lt_elementwise(1e10 as i32).all();
assert!(result);
for i in 0..matrix.rows() {
for j in 0..matrix.cols() {
let vx = matrix[(i, j)];
let c = &(matrix.le_elementwise(vx)) & &(matrix.ge_elementwise(vx));
assert_eq!(c.count(), 1);
}
}
}
#[test]
fn test_arithmetic_broadcast() {
let matrix = create_test_matrix();
let result = &matrix + 1;
for i in 0..matrix.rows() {
for j in 0..matrix.cols() {
assert_eq!(result[(i, j)], matrix[(i, j)] + 1);
}
}
// test mul and div
let result = &matrix * 2;
let result2 = &matrix / 2;
for i in 0..matrix.rows() {
for j in 0..matrix.cols() {
assert_eq!(result[(i, j)], matrix[(i, j)] * 2);
assert_eq!(result2[(i, j)], matrix[(i, j)] / 2);
}
}
}
}