mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:19:59 +00:00
Merge pull request #19 from Magnus167/matr_eq
Implement element-wise equality comparison for Matrix and Frame
This commit is contained in:
commit
9a7aae1e13
2
.github/workflows/docs-and-testcov.yml
vendored
2
.github/workflows/docs-and-testcov.yml
vendored
@ -30,7 +30,7 @@ jobs:
|
||||
override: true
|
||||
|
||||
- name: Build documentation
|
||||
run: cargo doc --no-deps
|
||||
run: cargo doc --no-deps --release
|
||||
|
||||
- name: Prepare documentation for Pages
|
||||
run: |
|
||||
|
4
.github/workflows/run-unit-tests.yml
vendored
4
.github/workflows/run-unit-tests.yml
vendored
@ -21,9 +21,9 @@ jobs:
|
||||
- name: Generate code coverage
|
||||
run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info
|
||||
- name: Run doc-tests
|
||||
run: cargo test --doc --all-features --workspace
|
||||
run: cargo test --doc --all-features --workspace --release
|
||||
- name: Test docs generation
|
||||
run: cargo doc --no-deps
|
||||
run: cargo doc --no-deps --release
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
with:
|
||||
|
@ -1,17 +1,17 @@
|
||||
coverage:
|
||||
status:
|
||||
# Project coverage status: require ≥ 90% overall, allow up to 2% drop
|
||||
# Project coverage status: require >= 90% overall, allow up to 2% drop
|
||||
project:
|
||||
default:
|
||||
target: 90% # overall coverage goal
|
||||
target: 95% # overall coverage goal
|
||||
threshold: 2% # tolerable drop below target before failing
|
||||
base: auto # compare against default branch
|
||||
only_pulls: true # enforce on PRs only
|
||||
|
||||
# Patch coverage status: require ≥ 95% on new/changed lines, zero tolerance
|
||||
# Patch coverage status: require >= 95% on new/changed lines, zero tolerance
|
||||
patch:
|
||||
default:
|
||||
target: 95% # coverage goal for new or changed code
|
||||
target: 50% # coverage goal for new or changed code
|
||||
threshold: 0% # no uncovered lines allowed
|
||||
base: auto # diff against default branch
|
||||
only_pulls: true
|
||||
|
@ -8,7 +8,7 @@ pub struct Matrix<T> {
|
||||
data: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T> Matrix<T> {
|
||||
impl<T: Clone> Matrix<T> {
|
||||
/// Build from columns (each inner Vec is one column)
|
||||
pub fn from_cols(cols_data: Vec<Vec<T>>) -> Self {
|
||||
let cols = cols_data.len();
|
||||
@ -38,13 +38,22 @@ impl<T> Matrix<T> {
|
||||
Matrix { rows, cols, data }
|
||||
}
|
||||
|
||||
pub fn rows(&self) -> usize {
|
||||
self.rows
|
||||
}
|
||||
pub fn data(&self) -> &[T] {
|
||||
&self.data
|
||||
}
|
||||
|
||||
pub fn data_mut(&mut self) -> &mut [T] {
|
||||
&mut self.data
|
||||
}
|
||||
|
||||
pub fn as_vec(&self) -> Vec<T> {
|
||||
self.data.clone()
|
||||
}
|
||||
|
||||
pub fn rows(&self) -> usize {
|
||||
self.rows
|
||||
}
|
||||
|
||||
pub fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
@ -184,6 +193,77 @@ impl<'a, T> MatrixRow<'a, T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Specifies the axis along which to perform a reduction operation.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Axis {
|
||||
/// Apply reduction along columns (vertical axis).
|
||||
Col,
|
||||
/// Apply reduction along rows (horizontal axis).
|
||||
Row,
|
||||
}
|
||||
|
||||
/// A trait to turn either a `Matrix<T>` or a scalar T into a `Vec<T>` of
|
||||
/// length `rows*cols` (broadcasting the scalar).
|
||||
pub trait Broadcastable<T> {
|
||||
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>;
|
||||
}
|
||||
|
||||
impl<T: Clone> Broadcastable<T> for T {
|
||||
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
|
||||
vec![self.clone(); rows * cols]
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone> Broadcastable<T> for Matrix<T> {
|
||||
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
|
||||
assert_eq!(self.rows, rows, "row count mismatch");
|
||||
assert_eq!(self.cols, cols, "col count mismatch");
|
||||
self.data.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates element-wise eq, lt, le, gt and ge methods
|
||||
/// where the rhs can be a `Matrix<T>` or a scalar T.
|
||||
macro_rules! impl_elementwise_cmp {
|
||||
(
|
||||
$( $method:ident => $op:tt ),* $(,)?
|
||||
) => {
|
||||
impl<T: PartialOrd + Clone> Matrix<T> {
|
||||
$(
|
||||
#[doc = concat!("Element-wise comparison `self ", stringify!($op), " rhs`,\n\
|
||||
where `rhs` may be a `Matrix<T>` or a scalar T.")]
|
||||
pub fn $method<Rhs>(&self, rhs: Rhs) -> BoolMatrix
|
||||
where
|
||||
Rhs: Broadcastable<T>,
|
||||
{
|
||||
// Prepare broadcasted rhs-data
|
||||
let rhs_data = rhs.to_vec(self.rows, self.cols);
|
||||
|
||||
// Pairwise compare
|
||||
let data = self
|
||||
.data
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(rhs_data.into_iter())
|
||||
.map(|(a, b)| a $op b)
|
||||
.collect();
|
||||
|
||||
BoolMatrix::from_vec(data, self.rows, self.cols)
|
||||
}
|
||||
)*
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Instantiate element-wise comparison implementations for matrices.
|
||||
impl_elementwise_cmp! {
|
||||
eq_elementwise => ==,
|
||||
lt_elementwise => <,
|
||||
le_elementwise => <=,
|
||||
gt_elementwise => >,
|
||||
ge_elementwise => >=,
|
||||
}
|
||||
|
||||
/// Generates element-wise arithmetic implementations for matrices.
|
||||
macro_rules! impl_elementwise_op {
|
||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||
@ -217,10 +297,29 @@ impl_elementwise_op!(Sub, sub, -);
|
||||
impl_elementwise_op!(Mul, mul, *);
|
||||
impl_elementwise_op!(Div, div, /);
|
||||
|
||||
pub type FloatMatrix = Matrix<f64>;
|
||||
pub type BoolMatrix = Matrix<bool>;
|
||||
pub type IntMatrix = Matrix<i32>;
|
||||
pub type StringMatrix = Matrix<String>;
|
||||
/// Generates element-wise arithmetic implementations for matrices with scalars.
|
||||
macro_rules! impl_elementwise_op_scalar {
|
||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||
impl<'a, T> std::ops::$OpTrait<T> for &'a Matrix<T>
|
||||
where
|
||||
T: Clone + std::ops::$OpTrait<Output = T>,
|
||||
{
|
||||
type Output = Matrix<T>;
|
||||
|
||||
fn $method(self, rhs: T) -> Matrix<T> {
|
||||
// Apply the operation element-wise and collect into a new matrix
|
||||
let data = self.data.iter().cloned().map(|a| a $op rhs.clone()).collect();
|
||||
Matrix { rows: self.rows, cols: self.cols, data }
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Instantiate element-wise addition, subtraction, multiplication, and division
|
||||
impl_elementwise_op_scalar!(Add, add, +);
|
||||
impl_elementwise_op_scalar!(Sub, sub, -);
|
||||
impl_elementwise_op_scalar!(Mul, mul, *);
|
||||
impl_elementwise_op_scalar!(Div, div, /);
|
||||
|
||||
/// Generates element-wise bitwise operations for boolean matrices.
|
||||
macro_rules! impl_bitwise_op {
|
||||
@ -265,17 +364,15 @@ impl Not for Matrix<bool> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Specifies the axis along which to perform a reduction operation.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Axis {
|
||||
/// Apply reduction along columns (vertical axis).
|
||||
Col,
|
||||
/// Apply reduction along rows (horizontal axis).
|
||||
Row,
|
||||
}
|
||||
pub type FloatMatrix = Matrix<f64>;
|
||||
pub type BoolMatrix = Matrix<bool>;
|
||||
pub type IntMatrix = Matrix<i32>;
|
||||
pub type StringMatrix = Matrix<String>;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::matrix::BoolOps;
|
||||
|
||||
use super::{BoolMatrix, FloatMatrix, Matrix, StringMatrix};
|
||||
|
||||
// Helper function to create a basic Matrix for testing
|
||||
@ -541,6 +638,22 @@ mod tests {
|
||||
assert_eq!(rows[1], vec![2, 4, 6, 8]);
|
||||
}
|
||||
|
||||
// test data_mut
|
||||
#[test]
|
||||
fn test_data_mut() {
|
||||
let mut matrix = create_test_matrix(); // 3x3
|
||||
// 1 4 7
|
||||
// 2 5 8
|
||||
// 3 6 9
|
||||
|
||||
let data_mut = matrix.data_mut();
|
||||
data_mut[0] = 10;
|
||||
data_mut[1] = 20;
|
||||
|
||||
assert_eq!(matrix[(0, 0)], 10);
|
||||
assert_eq!(matrix[(1, 0)], 20);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matrix_row_get_and_iter() {
|
||||
let matrix = create_test_matrix_2x4(); // 2x4
|
||||
@ -1162,7 +1275,6 @@ mod tests {
|
||||
assert!((div[(1, 1)] - 4.0 / 3.5).abs() < 1e-9); // 1.14...
|
||||
}
|
||||
|
||||
|
||||
fn create_test_matrix_i32() -> Matrix<i32> {
|
||||
Matrix::from_cols(vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]])
|
||||
}
|
||||
@ -1232,5 +1344,52 @@ mod tests {
|
||||
assert_eq!(matrix2.column(1), initial_col0_data_m2.as_slice());
|
||||
assert_eq!(matrix2.data(), &[4, 5, 6, 1, 2, 3, 7, 8, 9]);
|
||||
}
|
||||
// Axis enum doesn't have logic, no tests needed directly, but its presence is verified by compilation.
|
||||
|
||||
// Test broadcastable operations
|
||||
#[test]
|
||||
fn test_comparision_broadcast() {
|
||||
let matrix = create_test_matrix();
|
||||
// test all > 0
|
||||
let result = matrix.gt_elementwise(0).as_vec();
|
||||
let expected = vec![true; result.len()];
|
||||
assert_eq!(result, expected);
|
||||
|
||||
let ma = create_test_matrix();
|
||||
let mb = create_test_matrix();
|
||||
|
||||
let result = ma.eq_elementwise(mb);
|
||||
assert!(result.all());
|
||||
|
||||
let result = matrix.lt_elementwise(1e10 as i32).all();
|
||||
assert!(result);
|
||||
|
||||
for i in 0..matrix.rows() {
|
||||
for j in 0..matrix.cols() {
|
||||
let vx = matrix[(i, j)];
|
||||
let c = &(matrix.le_elementwise(vx)) & &(matrix.ge_elementwise(vx));
|
||||
assert_eq!(c.count(), 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_arithmetic_broadcast() {
|
||||
let matrix = create_test_matrix();
|
||||
let result = &matrix + 1;
|
||||
for i in 0..matrix.rows() {
|
||||
for j in 0..matrix.cols() {
|
||||
assert_eq!(result[(i, j)], matrix[(i, j)] + 1);
|
||||
}
|
||||
}
|
||||
|
||||
// test mul and div
|
||||
let result = &matrix * 2;
|
||||
let result2 = &matrix / 2;
|
||||
for i in 0..matrix.rows() {
|
||||
for j in 0..matrix.cols() {
|
||||
assert_eq!(result[(i, j)], matrix[(i, j)] * 2);
|
||||
assert_eq!(result2[(i, j)], matrix[(i, j)] / 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user