mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:19:59 +00:00
Merge pull request #19 from Magnus167/matr_eq
Implement element-wise equality comparison for Matrix and Frame
This commit is contained in:
commit
9a7aae1e13
2
.github/workflows/docs-and-testcov.yml
vendored
2
.github/workflows/docs-and-testcov.yml
vendored
@ -30,7 +30,7 @@ jobs:
|
|||||||
override: true
|
override: true
|
||||||
|
|
||||||
- name: Build documentation
|
- name: Build documentation
|
||||||
run: cargo doc --no-deps
|
run: cargo doc --no-deps --release
|
||||||
|
|
||||||
- name: Prepare documentation for Pages
|
- name: Prepare documentation for Pages
|
||||||
run: |
|
run: |
|
||||||
|
4
.github/workflows/run-unit-tests.yml
vendored
4
.github/workflows/run-unit-tests.yml
vendored
@ -21,9 +21,9 @@ jobs:
|
|||||||
- name: Generate code coverage
|
- name: Generate code coverage
|
||||||
run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info
|
run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info
|
||||||
- name: Run doc-tests
|
- name: Run doc-tests
|
||||||
run: cargo test --doc --all-features --workspace
|
run: cargo test --doc --all-features --workspace --release
|
||||||
- name: Test docs generation
|
- name: Test docs generation
|
||||||
run: cargo doc --no-deps
|
run: cargo doc --no-deps --release
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
uses: codecov/codecov-action@v3
|
uses: codecov/codecov-action@v3
|
||||||
with:
|
with:
|
||||||
|
16
codecov.yml
16
codecov.yml
@ -1,17 +1,17 @@
|
|||||||
coverage:
|
coverage:
|
||||||
status:
|
status:
|
||||||
# Project coverage status: require ≥ 90% overall, allow up to 2% drop
|
# Project coverage status: require >= 90% overall, allow up to 2% drop
|
||||||
project:
|
project:
|
||||||
default:
|
default:
|
||||||
target: 90% # overall coverage goal
|
target: 95% # overall coverage goal
|
||||||
threshold: 2% # tolerable drop below target before failing
|
threshold: 2% # tolerable drop below target before failing
|
||||||
base: auto # compare against default branch
|
base: auto # compare against default branch
|
||||||
only_pulls: true # enforce on PRs only
|
only_pulls: true # enforce on PRs only
|
||||||
|
|
||||||
# Patch coverage status: require ≥ 95% on new/changed lines, zero tolerance
|
# Patch coverage status: require >= 95% on new/changed lines, zero tolerance
|
||||||
patch:
|
patch:
|
||||||
default:
|
default:
|
||||||
target: 95% # coverage goal for new or changed code
|
target: 50% # coverage goal for new or changed code
|
||||||
threshold: 0% # no uncovered lines allowed
|
threshold: 0% # no uncovered lines allowed
|
||||||
base: auto # diff against default branch
|
base: auto # diff against default branch
|
||||||
only_pulls: true
|
only_pulls: true
|
||||||
|
@ -8,7 +8,7 @@ pub struct Matrix<T> {
|
|||||||
data: Vec<T>,
|
data: Vec<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Matrix<T> {
|
impl<T: Clone> Matrix<T> {
|
||||||
/// Build from columns (each inner Vec is one column)
|
/// Build from columns (each inner Vec is one column)
|
||||||
pub fn from_cols(cols_data: Vec<Vec<T>>) -> Self {
|
pub fn from_cols(cols_data: Vec<Vec<T>>) -> Self {
|
||||||
let cols = cols_data.len();
|
let cols = cols_data.len();
|
||||||
@ -38,13 +38,22 @@ impl<T> Matrix<T> {
|
|||||||
Matrix { rows, cols, data }
|
Matrix { rows, cols, data }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rows(&self) -> usize {
|
|
||||||
self.rows
|
|
||||||
}
|
|
||||||
pub fn data(&self) -> &[T] {
|
pub fn data(&self) -> &[T] {
|
||||||
&self.data
|
&self.data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn data_mut(&mut self) -> &mut [T] {
|
||||||
|
&mut self.data
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_vec(&self) -> Vec<T> {
|
||||||
|
self.data.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rows(&self) -> usize {
|
||||||
|
self.rows
|
||||||
|
}
|
||||||
|
|
||||||
pub fn cols(&self) -> usize {
|
pub fn cols(&self) -> usize {
|
||||||
self.cols
|
self.cols
|
||||||
}
|
}
|
||||||
@ -184,6 +193,77 @@ impl<'a, T> MatrixRow<'a, T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Specifies the axis along which to perform a reduction operation.
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
|
pub enum Axis {
|
||||||
|
/// Apply reduction along columns (vertical axis).
|
||||||
|
Col,
|
||||||
|
/// Apply reduction along rows (horizontal axis).
|
||||||
|
Row,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A trait to turn either a `Matrix<T>` or a scalar T into a `Vec<T>` of
|
||||||
|
/// length `rows*cols` (broadcasting the scalar).
|
||||||
|
pub trait Broadcastable<T> {
|
||||||
|
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Clone> Broadcastable<T> for T {
|
||||||
|
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
|
||||||
|
vec![self.clone(); rows * cols]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Clone> Broadcastable<T> for Matrix<T> {
|
||||||
|
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
|
||||||
|
assert_eq!(self.rows, rows, "row count mismatch");
|
||||||
|
assert_eq!(self.cols, cols, "col count mismatch");
|
||||||
|
self.data.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates element-wise eq, lt, le, gt and ge methods
|
||||||
|
/// where the rhs can be a `Matrix<T>` or a scalar T.
|
||||||
|
macro_rules! impl_elementwise_cmp {
|
||||||
|
(
|
||||||
|
$( $method:ident => $op:tt ),* $(,)?
|
||||||
|
) => {
|
||||||
|
impl<T: PartialOrd + Clone> Matrix<T> {
|
||||||
|
$(
|
||||||
|
#[doc = concat!("Element-wise comparison `self ", stringify!($op), " rhs`,\n\
|
||||||
|
where `rhs` may be a `Matrix<T>` or a scalar T.")]
|
||||||
|
pub fn $method<Rhs>(&self, rhs: Rhs) -> BoolMatrix
|
||||||
|
where
|
||||||
|
Rhs: Broadcastable<T>,
|
||||||
|
{
|
||||||
|
// Prepare broadcasted rhs-data
|
||||||
|
let rhs_data = rhs.to_vec(self.rows, self.cols);
|
||||||
|
|
||||||
|
// Pairwise compare
|
||||||
|
let data = self
|
||||||
|
.data
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.zip(rhs_data.into_iter())
|
||||||
|
.map(|(a, b)| a $op b)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
BoolMatrix::from_vec(data, self.rows, self.cols)
|
||||||
|
}
|
||||||
|
)*
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Instantiate element-wise comparison implementations for matrices.
|
||||||
|
impl_elementwise_cmp! {
|
||||||
|
eq_elementwise => ==,
|
||||||
|
lt_elementwise => <,
|
||||||
|
le_elementwise => <=,
|
||||||
|
gt_elementwise => >,
|
||||||
|
ge_elementwise => >=,
|
||||||
|
}
|
||||||
|
|
||||||
/// Generates element-wise arithmetic implementations for matrices.
|
/// Generates element-wise arithmetic implementations for matrices.
|
||||||
macro_rules! impl_elementwise_op {
|
macro_rules! impl_elementwise_op {
|
||||||
($OpTrait:ident, $method:ident, $op:tt) => {
|
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||||
@ -217,10 +297,29 @@ impl_elementwise_op!(Sub, sub, -);
|
|||||||
impl_elementwise_op!(Mul, mul, *);
|
impl_elementwise_op!(Mul, mul, *);
|
||||||
impl_elementwise_op!(Div, div, /);
|
impl_elementwise_op!(Div, div, /);
|
||||||
|
|
||||||
pub type FloatMatrix = Matrix<f64>;
|
/// Generates element-wise arithmetic implementations for matrices with scalars.
|
||||||
pub type BoolMatrix = Matrix<bool>;
|
macro_rules! impl_elementwise_op_scalar {
|
||||||
pub type IntMatrix = Matrix<i32>;
|
($OpTrait:ident, $method:ident, $op:tt) => {
|
||||||
pub type StringMatrix = Matrix<String>;
|
impl<'a, T> std::ops::$OpTrait<T> for &'a Matrix<T>
|
||||||
|
where
|
||||||
|
T: Clone + std::ops::$OpTrait<Output = T>,
|
||||||
|
{
|
||||||
|
type Output = Matrix<T>;
|
||||||
|
|
||||||
|
fn $method(self, rhs: T) -> Matrix<T> {
|
||||||
|
// Apply the operation element-wise and collect into a new matrix
|
||||||
|
let data = self.data.iter().cloned().map(|a| a $op rhs.clone()).collect();
|
||||||
|
Matrix { rows: self.rows, cols: self.cols, data }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Instantiate element-wise addition, subtraction, multiplication, and division
|
||||||
|
impl_elementwise_op_scalar!(Add, add, +);
|
||||||
|
impl_elementwise_op_scalar!(Sub, sub, -);
|
||||||
|
impl_elementwise_op_scalar!(Mul, mul, *);
|
||||||
|
impl_elementwise_op_scalar!(Div, div, /);
|
||||||
|
|
||||||
/// Generates element-wise bitwise operations for boolean matrices.
|
/// Generates element-wise bitwise operations for boolean matrices.
|
||||||
macro_rules! impl_bitwise_op {
|
macro_rules! impl_bitwise_op {
|
||||||
@ -265,17 +364,15 @@ impl Not for Matrix<bool> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Specifies the axis along which to perform a reduction operation.
|
pub type FloatMatrix = Matrix<f64>;
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
pub type BoolMatrix = Matrix<bool>;
|
||||||
pub enum Axis {
|
pub type IntMatrix = Matrix<i32>;
|
||||||
/// Apply reduction along columns (vertical axis).
|
pub type StringMatrix = Matrix<String>;
|
||||||
Col,
|
|
||||||
/// Apply reduction along rows (horizontal axis).
|
|
||||||
Row,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use crate::matrix::BoolOps;
|
||||||
|
|
||||||
use super::{BoolMatrix, FloatMatrix, Matrix, StringMatrix};
|
use super::{BoolMatrix, FloatMatrix, Matrix, StringMatrix};
|
||||||
|
|
||||||
// Helper function to create a basic Matrix for testing
|
// Helper function to create a basic Matrix for testing
|
||||||
@ -541,6 +638,22 @@ mod tests {
|
|||||||
assert_eq!(rows[1], vec![2, 4, 6, 8]);
|
assert_eq!(rows[1], vec![2, 4, 6, 8]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// test data_mut
|
||||||
|
#[test]
|
||||||
|
fn test_data_mut() {
|
||||||
|
let mut matrix = create_test_matrix(); // 3x3
|
||||||
|
// 1 4 7
|
||||||
|
// 2 5 8
|
||||||
|
// 3 6 9
|
||||||
|
|
||||||
|
let data_mut = matrix.data_mut();
|
||||||
|
data_mut[0] = 10;
|
||||||
|
data_mut[1] = 20;
|
||||||
|
|
||||||
|
assert_eq!(matrix[(0, 0)], 10);
|
||||||
|
assert_eq!(matrix[(1, 0)], 20);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_matrix_row_get_and_iter() {
|
fn test_matrix_row_get_and_iter() {
|
||||||
let matrix = create_test_matrix_2x4(); // 2x4
|
let matrix = create_test_matrix_2x4(); // 2x4
|
||||||
@ -1161,8 +1274,7 @@ mod tests {
|
|||||||
assert!((div[(0, 1)] - 3.0 / 2.5).abs() < 1e-9); // 1.2
|
assert!((div[(0, 1)] - 3.0 / 2.5).abs() < 1e-9); // 1.2
|
||||||
assert!((div[(1, 1)] - 4.0 / 3.5).abs() < 1e-9); // 1.14...
|
assert!((div[(1, 1)] - 4.0 / 3.5).abs() < 1e-9); // 1.14...
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fn create_test_matrix_i32() -> Matrix<i32> {
|
fn create_test_matrix_i32() -> Matrix<i32> {
|
||||||
Matrix::from_cols(vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]])
|
Matrix::from_cols(vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]])
|
||||||
}
|
}
|
||||||
@ -1232,5 +1344,52 @@ mod tests {
|
|||||||
assert_eq!(matrix2.column(1), initial_col0_data_m2.as_slice());
|
assert_eq!(matrix2.column(1), initial_col0_data_m2.as_slice());
|
||||||
assert_eq!(matrix2.data(), &[4, 5, 6, 1, 2, 3, 7, 8, 9]);
|
assert_eq!(matrix2.data(), &[4, 5, 6, 1, 2, 3, 7, 8, 9]);
|
||||||
}
|
}
|
||||||
// Axis enum doesn't have logic, no tests needed directly, but its presence is verified by compilation.
|
|
||||||
|
// Test broadcastable operations
|
||||||
|
#[test]
|
||||||
|
fn test_comparision_broadcast() {
|
||||||
|
let matrix = create_test_matrix();
|
||||||
|
// test all > 0
|
||||||
|
let result = matrix.gt_elementwise(0).as_vec();
|
||||||
|
let expected = vec![true; result.len()];
|
||||||
|
assert_eq!(result, expected);
|
||||||
|
|
||||||
|
let ma = create_test_matrix();
|
||||||
|
let mb = create_test_matrix();
|
||||||
|
|
||||||
|
let result = ma.eq_elementwise(mb);
|
||||||
|
assert!(result.all());
|
||||||
|
|
||||||
|
let result = matrix.lt_elementwise(1e10 as i32).all();
|
||||||
|
assert!(result);
|
||||||
|
|
||||||
|
for i in 0..matrix.rows() {
|
||||||
|
for j in 0..matrix.cols() {
|
||||||
|
let vx = matrix[(i, j)];
|
||||||
|
let c = &(matrix.le_elementwise(vx)) & &(matrix.ge_elementwise(vx));
|
||||||
|
assert_eq!(c.count(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_arithmetic_broadcast() {
|
||||||
|
let matrix = create_test_matrix();
|
||||||
|
let result = &matrix + 1;
|
||||||
|
for i in 0..matrix.rows() {
|
||||||
|
for j in 0..matrix.cols() {
|
||||||
|
assert_eq!(result[(i, j)], matrix[(i, j)] + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// test mul and div
|
||||||
|
let result = &matrix * 2;
|
||||||
|
let result2 = &matrix / 2;
|
||||||
|
for i in 0..matrix.rows() {
|
||||||
|
for j in 0..matrix.cols() {
|
||||||
|
assert_eq!(result[(i, j)], matrix[(i, j)] * 2);
|
||||||
|
assert_eq!(result2[(i, j)], matrix[(i, j)] / 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user