2025-07-12 00:55:44 +01:00

2085 lines
68 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! A simple column-major Matrix implementation with element-wise operations.
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Index, IndexMut, Mul, Not, Sub};
/// A columnmajor 2D matrix of `T`. Index as `Array(row, column)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Matrix<T> {
rows: usize,
cols: usize,
data: Vec<T>,
}
impl<T: Clone> Matrix<T> {
/// Build from columns (each inner Vec is one column)
pub fn from_cols(cols_data: Vec<Vec<T>>) -> Self {
let cols = cols_data.len();
assert!(cols > 0, "need at least one column");
// Handle empty cols_data
let rows = cols_data.get(0).map_or(0, |c| c.len());
// Allow 0-row matrices if columns are empty, but not 0-col matrices if rows > 0
assert!(
rows > 0 || cols == 0,
"need at least one row if columns exist"
);
for (i, col) in cols_data.iter().enumerate() {
assert!(
col.len() == rows,
"col {} has len {}, expected {}",
i,
col.len(),
rows
);
}
// Flatten column data directly
let data = cols_data.into_iter().flatten().collect();
Matrix { rows, cols, data }
}
/// Build from a flat Vec, assuming column-major order.
pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> Self {
assert!(
rows > 0 || cols == 0,
"need at least one row if columns exist"
);
assert!(
cols > 0 || rows == 0,
"need at least one column if rows exist"
);
if rows * cols != 0 {
// Only assert length if matrix is non-empty
assert_eq!(
data.len(),
rows * cols,
"data length mismatch: expected {}, got {}",
rows * cols,
data.len()
);
} else {
assert!(data.is_empty(), "data must be empty for 0-sized matrix");
}
Matrix { rows, cols, data }
}
/// Build from a flat Vec, assuming row-major order.
pub fn from_rows_vec(data: Vec<T>, rows: usize, cols: usize) -> Self {
let mut new_vec = Vec::with_capacity(rows * cols);
for c in 0..cols {
for r in 0..rows {
new_vec.push(data[r * cols + c].clone());
}
}
Matrix::from_vec(new_vec, rows, cols)
}
pub fn data(&self) -> &[T] {
&self.data
}
pub fn data_mut(&mut self) -> &mut [T] {
&mut self.data
}
/// Consumes the Matrix and returns its underlying data Vec.
pub fn into_vec(self) -> Vec<T> {
self.data
}
/// Creates a new `Vec<T>` containing the matrix data (cloned).
pub fn to_vec(&self) -> Vec<T> {
self.data.clone()
}
pub fn rows(&self) -> usize {
self.rows
}
pub fn cols(&self) -> usize {
self.cols
}
pub fn shape(&self) -> (usize, usize) {
(self.rows, self.cols)
}
/// Get element reference (immutable). Panics on out-of-bounds.
pub fn get(&self, r: usize, c: usize) -> &T {
&self[(r, c)]
}
/// Get element reference (mutable). Panics on out-of-bounds.
pub fn get_mut(&mut self, r: usize, c: usize) -> &mut T {
&mut self[(r, c)]
}
#[inline]
pub fn column(&self, c: usize) -> &[T] {
assert!(
c < self.cols,
"column index {} out of bounds for {} columns",
c,
self.cols
);
let start = c * self.rows;
&self.data[start..start + self.rows]
}
#[inline]
pub fn column_mut(&mut self, c: usize) -> &mut [T] {
assert!(
c < self.cols,
"column index {} out of bounds for {} columns",
c,
self.cols
);
let start = c * self.rows;
&mut self.data[start..start + self.rows]
}
pub fn iter_columns(&self) -> impl Iterator<Item = &[T]> {
(0..self.cols).map(move |c| self.column(c))
}
pub fn iter_rows(&self) -> impl Iterator<Item = MatrixRow<'_, T>> {
(0..self.rows).map(move |r| MatrixRow {
matrix: self,
row: r,
})
}
/// Swaps two columns in the matrix. Panics on out-of-bounds.
pub fn swap_columns(&mut self, c1: usize, c2: usize) {
assert!(
c1 < self.cols,
"column index c1={} out of bounds for {} columns",
c1,
self.cols
);
assert!(
c2 < self.cols,
"column index c2={} out of bounds for {} columns",
c2,
self.cols
);
if c1 == c2 || self.rows == 0 || self.cols == 0 {
return;
}
let (start1, end1) = (c1 * self.rows, (c1 + 1) * self.rows);
let (start2, end2) = (c2 * self.rows, (c2 + 1) * self.rows);
if (start1 < start2 && end1 > start2) || (start2 < start1 && end2 > start1) {
panic!("Cannot swap overlapping columns");
}
// element-wise swap
for r in 0..self.rows {
self.data.swap(start1 + r, start2 + r);
}
}
/// Deletes a column from the matrix. Panics on out-of-bounds.
/// This is O(N) where N is the number of elements.
pub fn delete_column(&mut self, col: usize) {
assert!(
col < self.cols,
"column index {} out of bounds for {} columns",
col,
self.cols
);
let start = col * self.rows;
self.data.drain(start..start + self.rows); // Efficient removal
self.cols -= 1;
}
#[inline]
pub fn row(&self, r: usize) -> Vec<T> {
assert!(
r < self.rows,
"row index {} out of bounds for {} rows",
r,
self.rows
);
let mut row_data = Vec::with_capacity(self.cols);
for c in 0..self.cols {
row_data.push(self[(r, c)].clone()); // Clone each element
}
row_data
}
pub fn row_copy_from_slice(&mut self, r: usize, values: &[T]) {
assert!(
r < self.rows,
"row index {} out of bounds for {} rows",
r,
self.rows
);
assert!(
values.len() == self.cols,
"input slice length {} does not match number of columns {}",
values.len(),
self.cols
);
for (c, value) in values.iter().enumerate() {
let idx = r + c * self.rows; // column-major index
self.data[idx] = value.clone();
}
}
/// Deletes a row from the matrix. Panics on out-of-bounds.
/// This is O(N) where N is the number of elements, as it rebuilds the data vec.
pub fn delete_row(&mut self, row: usize) {
assert!(
row < self.rows,
"row index {} out of bounds for {} rows",
row,
self.rows
);
if self.rows == 0 {
return;
} // Nothing to delete
let old_rows = self.rows;
let new_rows = self.rows - 1;
let mut new_data = Vec::with_capacity(new_rows * self.cols);
for c in 0..self.cols {
let col_start_old = c * old_rows;
for r in 0..old_rows {
if r != row {
// Must clone as we are reading from the old data while building the new one
new_data.push(self.data[col_start_old + r].clone());
}
}
}
self.data = new_data;
self.rows = new_rows;
}
pub fn transpose(&self) -> Matrix<T> {
let (m, n) = (self.rows, self.cols);
let mut transposed_data = Vec::with_capacity(m * n);
// In the transposed matrix the old rows become the new columns.
for j in 0..m {
// new column index = old row index
for i in 0..n {
// new row index = old col index
transposed_data.push(self[(j, i)].clone()); // A(T)[i,j] = A[j,i]
}
}
Matrix::from_vec(transposed_data, n, m) // size is n × m
}
}
impl<T: Clone> Matrix<T> {
/// Adds a column to the matrix at the specified index. Panics if index > cols or length mismatch.
/// This is O(N) where N is the number of elements.
pub fn add_column(&mut self, index: usize, column: Vec<T>) {
assert!(
index <= self.cols,
"add_column index {} out of bounds for {} columns",
index,
self.cols
);
assert_eq!(
column.len(),
self.rows,
"column length mismatch: expected {}, got {}",
self.rows,
column.len()
);
if self.rows == 0 && self.cols == 0 {
// Special case: adding first column to empty matrix
assert!(index == 0, "index must be 0 for adding first column");
self.data = column;
self.cols = 1;
// self.rows should be correctly set by column.len() assertion
} else {
let insert_pos = index * self.rows;
self.data.splice(insert_pos..insert_pos, column); // Efficient insertion
self.cols += 1;
}
}
/// Adds a row to the matrix at the specified index. Panics if index > rows or length mismatch.
/// This is O(N) where N is the number of elements, as it rebuilds the data vec.
pub fn add_row(&mut self, index: usize, row: Vec<T>) {
assert!(
index <= self.rows,
"add_row index {} out of bounds for {} rows",
index,
self.rows
);
assert_eq!(
row.len(),
self.cols,
"row length mismatch: expected {} (cols), got {}",
self.cols,
row.len()
);
if self.cols == 0 && self.rows == 0 {
// Special case: adding first row to empty matrix
assert!(index == 0, "index must be 0 for adding first row");
// Cannot add a row if there are no columns yet. Maybe panic or change API?
assert!(
self.cols > 0 || row.is_empty(),
"cannot add non-empty row to matrix with 0 columns"
);
if row.is_empty() {
return;
} // Adding empty row to empty matrix is no-op
}
let old_rows = self.rows;
let new_rows = self.rows + 1;
let mut new_data = Vec::with_capacity(new_rows * self.cols);
let mut row_iter = row.into_iter(); // Consume the input row vec
for c in 0..self.cols {
let old_col_start = c * old_rows;
for r in 0..new_rows {
if r == index {
// Take the next element from the provided row vector
new_data.push(row_iter.next().expect("Row iterator exhausted prematurely - should have been caught by length assert"));
} else {
// Calculate the corresponding old row index
let old_r = if r < index { r } else { r - 1 };
// Must clone as we are reading from the old data while building the new one
new_data.push(self.data[old_col_start + old_r].clone());
}
}
}
self.data = new_data;
self.rows = new_rows;
}
/// Return a new matrix where row 0 of `self` is repeated `n` times.
pub fn repeat_rows(&self, n: usize) -> Matrix<T>
where
T: Clone,
{
let mut data = Vec::with_capacity(n * self.cols());
let zeroth_row = self.row(0);
for value in &zeroth_row {
for _ in 0..n {
data.push(value.clone()); // Clone each element
}
}
Matrix::from_vec(data, n, self.cols)
}
/// Creates a new matrix filled with a specific value of the specified size.
pub fn filled(rows: usize, cols: usize, value: T) -> Self {
Matrix {
rows,
cols,
data: vec![value; rows * cols], // Fill with the specified value
}
}
/// Creates a new matrix by broadcasting a 1-row matrix to a target shape.
/// Panics if `self` is not a 1-row matrix or if `self.cols()` does not match `target_cols`.
pub fn broadcast_row_to_target_shape(
&self,
target_rows: usize,
target_cols: usize,
) -> Matrix<T> {
assert_eq!(
self.rows(),
1,
"broadcast_row_to_target_shape can only be called on a 1-row matrix."
);
assert_eq!(
self.cols(),
target_cols,
"Column count mismatch for broadcasting: source has {} columns, target has {} columns.",
self.cols(),
target_cols
);
let mut data = Vec::with_capacity(target_rows * target_cols);
let original_row_data = self.row(0); // Get the single row data
for _ in 0..target_rows {
// Repeat 'target_rows' times
for value in &original_row_data {
// Iterate over elements of the row
data.push(value.clone());
}
}
// The data is now in row-major order for the new matrix.
// We need to convert it to column-major for Matrix::from_vec.
Matrix::from_rows_vec(data, target_rows, target_cols)
}
}
impl Matrix<f64> {
/// Creates a new matrix filled with zeros of the specified size.
pub fn zeros(rows: usize, cols: usize) -> Self {
Matrix::filled(rows, cols, 0.0)
}
/// Creates a new matrix filled with ones of the specified size.
pub fn ones(rows: usize, cols: usize) -> Self {
Matrix::filled(rows, cols, 1.0)
}
/// Creates a new matrix filled with NaN values of the specified size.
pub fn nan(rows: usize, cols: usize) -> Matrix<f64> {
Matrix::filled(rows, cols, f64::NAN)
}
}
impl<T> Index<(usize, usize)> for Matrix<T> {
type Output = T;
#[inline]
fn index(&self, (r, c): (usize, usize)) -> &T {
// Validate that the requested indices are within bounds
assert!(
r < self.rows && c < self.cols,
"index out of bounds: ({}, {}) vs {}x{}",
r,
c,
self.rows,
self.cols
);
// Compute column-major offset and return reference
&self.data[c * self.rows + r]
}
}
impl<T> IndexMut<(usize, usize)> for Matrix<T> {
#[inline]
fn index_mut(&mut self, (r, c): (usize, usize)) -> &mut T {
// Validate that the requested indices are within bounds
assert!(
r < self.rows && c < self.cols,
"index out of bounds: ({}, {}) vs {}x{}",
r,
c,
self.rows,
self.cols
);
// Compute column-major offset and return mutable reference
&mut self.data[c * self.rows + r]
}
}
// --- Row Iterator Helper ---
/// Represents an immutable view of a single row in the matrix.
pub struct MatrixRow<'a, T> {
matrix: &'a Matrix<T>,
row: usize,
}
impl<'a, T> MatrixRow<'a, T> {
/// Returns a reference to the element at the given column in this row.
pub fn get(&self, c: usize) -> &T {
&self.matrix[(self.row, c)]
}
/// Returns an iterator over all elements in this row.
pub fn iter(&self) -> impl Iterator<Item = &T> {
(0..self.matrix.cols).map(move |c| self.get(c))
}
}
// --- Reduction Axis Enum ---
/// Specifies the axis along which to perform a reduction operation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Axis {
/// Apply reduction along columns (vertical axis). Result has 1 row.
Col,
/// Apply reduction along rows (horizontal axis). Result has 1 column.
Row,
}
// --- Broadcasting ---
/// A trait to turn either a `Matrix<T>` or a scalar T into a `Vec<T>` of
/// length `rows*cols` (broadcasting the scalar). Used for comparisons.
pub trait Broadcastable<T> {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>;
}
impl<T: Clone> Broadcastable<T> for T {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
vec![self.clone(); rows * cols]
}
}
impl<T: Clone> Broadcastable<T> for Matrix<T> {
fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
assert_eq!(self.rows, rows, "row count mismatch in broadcast");
assert_eq!(self.cols, cols, "col count mismatch in broadcast");
self.data.clone() // Clone the data for the broadcasted vec
}
}
// --- Element-wise Comparisons ---
/// Generates element-wise eq, lt, le, gt and ge methods
/// where the rhs can be a `Matrix<T>` or a scalar T.
macro_rules! impl_elementwise_cmp {
(
$( $method:ident => $op:tt ),* $(,)?
) => {
impl<T: PartialOrd + Clone> Matrix<T> {
$(
#[doc = concat!("Element-wise comparison `self ", stringify!($op), " rhs`,\n\
where `rhs` may be a `Matrix<T>` or a scalar T.\n\
Returns a `BoolMatrix`.")]
pub fn $method<Rhs>(&self, rhs: Rhs) -> BoolMatrix
where
Rhs: Broadcastable<T>,
{
// Prepare broadcasted rhs-data using the trait
let rhs_data = rhs.to_vec(self.rows, self.cols);
// Pairwise compare using iterators
let data = self
.data
.iter() // Borrow self's data
.zip(rhs_data.iter()) // Borrow rhs's broadcasted data
.map(|(a, b)| a $op b) // Perform comparison op
.collect();
// Create BoolMatrix from result
Matrix::<bool>::from_vec(data, self.rows, self.cols)
}
)*
}
};
}
// Instantiate element-wise comparison implementations for matrices.
impl_elementwise_cmp! {
eq_elem => ==,
ne_elem => !=,
lt_elem => <,
le_elem => <=,
gt_elem => >,
ge_elem => >=,
}
// --- Element-wise Arithmetic Operations (Macros generating all ownership variants) ---
fn check_matrix_dims_for_ops<T>(lhs: &Matrix<T>, rhs: &Matrix<T>) {
assert!(
lhs.rows == rhs.rows,
"Row count mismatch: left has {} rows, right has {} rows",
lhs.rows,
rhs.rows
);
assert!(
lhs.cols == rhs.cols,
"Column count mismatch: left has {} columns, right has {} columns",
lhs.cols,
rhs.cols
);
}
/// Generates element-wise arithmetic implementations for Matrix + Matrix
macro_rules! impl_elementwise_op_matrix_all {
($OpTrait:ident, $method:ident, $op:tt) => {
// &Matrix + &Matrix
impl<'a, 'b, T> $OpTrait<&'b Matrix<T>> for &'a Matrix<T>
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> {
check_matrix_dims_for_ops(self, rhs);
let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
Matrix { rows: self.rows, cols: self.cols, data }
}
}
// Matrix + &Matrix (Consumes self)
impl<'b, T> $OpTrait<&'b Matrix<T>> for Matrix<T>
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(mut self, rhs: &'b Matrix<T>) -> Matrix<T> { // Make self mutable for potential in-place modification
check_matrix_dims_for_ops(&self, rhs);
// Modify data in place
for (a, b) in self.data.iter_mut().zip(rhs.data.iter().cloned()) {
*a = a.clone() $op b; // Requires T: Clone for the *a = part
}
// Return modified self (its data vec was consumed conceptually)
self
// Alternative: Collect into new Vec if T is not Clone or in-place is complex
// let data = self.data.into_iter().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
// Matrix { rows: self.rows, cols: self.cols, data }
}
}
// &Matrix + Matrix (Consumes rhs)
impl<'a, T> $OpTrait<Matrix<T>> for &'a Matrix<T>
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(self, mut rhs: Matrix<T>) -> Matrix<T> { // Make rhs mutable
check_matrix_dims_for_ops(self, &rhs);
// Modify rhs data in place
for (a, b) in self.data.iter().cloned().zip(rhs.data.iter_mut()) {
*b = a $op b.clone(); // Requires T: Clone for the *b = part
}
// Return modified rhs
rhs
// Alternative: Collect into new Vec
// let data = self.data.iter().cloned().zip(rhs.data.into_iter()).map(|(a, b)| a $op b).collect();
// Matrix { rows: self.rows, cols: self.cols, data }
}
}
// Matrix + Matrix (Consumes both)
impl<T> $OpTrait<Matrix<T>> for Matrix<T>
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(mut self, rhs: Matrix<T>) -> Matrix<T> { // Make self mutable
check_matrix_dims_for_ops(&self, &rhs);
// Modify self data in place
for (a, b) in self.data.iter_mut().zip(rhs.data.into_iter()) {
*a = a.clone() $op b; // Requires T: Clone
}
// Return modified self
self
// Alternative: Collect into new Vec
// let data = self.data.into_iter().zip(rhs.data.into_iter()).map(|(a, b)| a $op b).collect();
// Matrix { rows: self.rows, cols: self.cols, data }
}
}
};
}
/// Generates element-wise arithmetic implementations for Matrix + Scalar
macro_rules! impl_elementwise_op_scalar_all {
($OpTrait:ident, $method:ident, $op:tt) => {
// &Matrix + Scalar
impl<'a, T> $OpTrait<T> for &'a Matrix<T>
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(self, rhs: T) -> Matrix<T> {
let data = self.data.iter().cloned().map(|a| a $op rhs.clone()).collect();
Matrix { rows: self.rows, cols: self.cols, data }
}
}
// Matrix + Scalar (Consumes self)
impl<T> $OpTrait<T> for Matrix<T>
where T: Clone + $OpTrait<Output = T> {
type Output = Matrix<T>;
fn $method(mut self, rhs: T) -> Matrix<T> { // Make self mutable
// Modify self data in place
for a in self.data.iter_mut() {
*a = a.clone() $op rhs.clone(); // Requires T: Clone
}
// Return modified self
self
// Alternative: Collect into new Vec
// let data = self.data.into_iter().map(|a| a $op rhs.clone()).collect();
// Matrix { rows: self.rows, cols: self.cols, data }
}
}
// NOTE: Scalar + Matrix (e.g., 1.0 + matrix) is NOT implemented here.
// It would require `impl Add<Matrix<T>> for T`, which is discouraged
// for primitive types unless inside the crate defining T.
};
}
// Instantiate ALL combinations for arithmetic ops using the new macros
impl_elementwise_op_matrix_all!(Add, add, +);
impl_elementwise_op_matrix_all!(Sub, sub, -);
impl_elementwise_op_matrix_all!(Mul, mul, *); // Element-wise multiplication
impl_elementwise_op_matrix_all!(Div, div, /); // Element-wise division
impl_elementwise_op_scalar_all!(Add, add, +);
impl_elementwise_op_scalar_all!(Sub, sub, -);
impl_elementwise_op_scalar_all!(Mul, mul, *);
impl_elementwise_op_scalar_all!(Div, div, /);
// --- Element-wise Bitwise Operations (BoolMatrix) ---
macro_rules! impl_bitwise_op_all {
($OpTrait:ident, $method:ident, $op:tt) => {
// &Matrix<bool> OP &Matrix<bool>
impl<'a, 'b> $OpTrait<&'b Matrix<bool>> for &'a Matrix<bool> {
type Output = Matrix<bool>;
fn $method(self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
check_matrix_dims_for_ops(self, rhs);
let data = self.data.iter().cloned().zip(rhs.data.iter().cloned()).map(|(a, b)| a $op b).collect();
Matrix { rows: self.rows, cols: self.cols, data }
}
}
// Matrix<bool> OP &Matrix<bool>
impl<'b> $OpTrait<&'b Matrix<bool>> for Matrix<bool> {
type Output = Matrix<bool>;
fn $method(mut self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
check_matrix_dims_for_ops(&self, rhs);
for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) { *a = *a $op *b; } // bool is Copy
self
}
}
// &Matrix<bool> OP Matrix<bool>
impl<'a> $OpTrait<Matrix<bool>> for &'a Matrix<bool> {
type Output = Matrix<bool>;
fn $method(self, mut rhs: Matrix<bool>) -> Matrix<bool> {
check_matrix_dims_for_ops(self, &rhs);
for (a, b) in self.data.iter().zip(rhs.data.iter_mut()) { *b = *a $op *b; } // bool is Copy
rhs
}
}
// Matrix<bool> OP Matrix<bool>
impl $OpTrait<Matrix<bool>> for Matrix<bool> {
type Output = Matrix<bool>;
fn $method(mut self, rhs: Matrix<bool>) -> Matrix<bool> {
check_matrix_dims_for_ops(&self, &rhs);
for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) { *a = *a $op *b; } // bool is Copy
self
}
}
};
}
// Instantiate ALL combinations for bitwise ops
impl_bitwise_op_all!(BitAnd, bitand, &);
impl_bitwise_op_all!(BitOr, bitor, |);
impl_bitwise_op_all!(BitXor, bitxor, ^);
// --- Logical Not ---
// `!Matrix<bool>` (consumes matrix)
impl Not for Matrix<bool> {
type Output = Matrix<bool>;
fn not(mut self) -> Matrix<bool> {
// Take by value, make mutable
for val in self.data.iter_mut() {
*val = !*val; // Invert in place
}
self // Return the modified matrix
}
}
// `!&Matrix<bool>` (borrows matrix, returns new matrix)
impl Not for &Matrix<bool> {
type Output = Matrix<bool>;
fn not(self) -> Matrix<bool> {
// Take by reference
let data = self.data.iter().map(|&v| !v).collect(); // Create new data vec
Matrix {
rows: self.rows,
cols: self.cols,
data,
}
}
}
// --- Type Aliases ---
pub type FloatMatrix = Matrix<f64>;
pub type BoolMatrix = Matrix<bool>;
pub type IntMatrix = Matrix<i32>;
pub type StringMatrix = Matrix<String>;
// --- Unit Tests ---
#[cfg(test)]
mod tests {
use crate::matrix::BoolOps;
use super::*; // Import items from outer scope
// Helper to create a 2x2 f64 matrix easily
fn make_f64_matrix(a: f64, b: f64, c: f64, d: f64) -> FloatMatrix {
Matrix::from_cols(vec![vec![a, c], vec![b, d]])
}
// Helper to create a 2x2 bool matrix easily
fn make_bool_matrix(a: bool, b: bool, c: bool, d: bool) -> BoolMatrix {
Matrix::from_cols(vec![vec![a, c], vec![b, d]])
}
// --- Arithmetic Tests ---
#[test]
fn test_add_f64() {
let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let m2 = make_f64_matrix(5.0, 6.0, 7.0, 8.0);
let expected = make_f64_matrix(6.0, 8.0, 10.0, 12.0);
assert_eq!(m1.clone() + m2.clone(), expected, "M + M");
assert_eq!(m1.clone() + &m2, expected, "M + &M");
assert_eq!(&m1 + m2.clone(), expected, "&M + M");
assert_eq!(&m1 + &m2, expected, "&M + &M");
}
#[test]
fn test_add_scalar_f64() {
let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let scalar = 10.0;
let expected = make_f64_matrix(11.0, 12.0, 13.0, 14.0);
assert_eq!(m1.clone() + scalar, expected, "M + S");
assert_eq!(&m1 + scalar, expected, "&M + S");
}
#[test]
fn test_sub_f64() {
let m1 = make_f64_matrix(10.0, 20.0, 30.0, 40.0);
let m2 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let expected = make_f64_matrix(9.0, 18.0, 27.0, 36.0);
assert_eq!(m1.clone() - m2.clone(), expected, "M - M");
assert_eq!(m1.clone() - &m2, expected, "M - &M");
assert_eq!(&m1 - m2.clone(), expected, "&M - M");
assert_eq!(&m1 - &m2, expected, "&M - &M");
}
#[test]
fn test_sub_scalar_f64() {
let m1 = make_f64_matrix(11.0, 12.0, 13.0, 14.0);
let scalar = 10.0;
let expected = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
assert_eq!(m1.clone() - scalar, expected, "M - S");
assert_eq!(&m1 - scalar, expected, "&M - S");
}
#[test]
fn test_mul_f64() {
// Element-wise
let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let m2 = make_f64_matrix(5.0, 6.0, 7.0, 8.0);
let expected = make_f64_matrix(5.0, 12.0, 21.0, 32.0);
assert_eq!(m1.clone() * m2.clone(), expected, "M * M");
assert_eq!(m1.clone() * &m2, expected, "M * &M");
assert_eq!(&m1 * m2.clone(), expected, "&M * M");
assert_eq!(&m1 * &m2, expected, "&M * &M");
}
#[test]
fn test_mul_scalar_f64() {
let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let scalar = 3.0;
let expected = make_f64_matrix(3.0, 6.0, 9.0, 12.0);
assert_eq!(m1.clone() * scalar, expected, "M * S");
assert_eq!(&m1 * scalar, expected, "&M * S");
}
#[test]
fn test_div_f64() {
// Element-wise
let m1 = make_f64_matrix(10.0, 20.0, 30.0, 40.0);
let m2 = make_f64_matrix(2.0, 5.0, 6.0, 8.0);
let expected = make_f64_matrix(5.0, 4.0, 5.0, 5.0);
assert_eq!(m1.clone() / m2.clone(), expected, "M / M");
assert_eq!(m1.clone() / &m2, expected, "M / &M");
assert_eq!(&m1 / m2.clone(), expected, "&M / M");
assert_eq!(&m1 / &m2, expected, "&M / &M");
}
#[test]
fn test_div_scalar_f64() {
let m1 = make_f64_matrix(10.0, 20.0, 30.0, 40.0);
let scalar = 10.0;
let expected = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
assert_eq!(m1.clone() / scalar, expected, "M / S");
assert_eq!(&m1 / scalar, expected, "&M / S");
}
#[test]
fn test_chained_ops_f64() {
let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let result = (((m.clone() + 1.0) * 2.0) - 4.0) / 2.0;
// Expected:
// m+1 = [2, 3], [4, 5]
// *2 = [4, 6], [8, 10]
// -4 = [0, 2], [4, 6]
// /2 = [0, 1], [2, 3]
let expected = make_f64_matrix(0.0, 1.0, 2.0, 3.0);
assert_eq!(result, expected);
}
// --- Boolean Logic Tests ---
#[test]
fn test_bitand_bool() {
let m1 = make_bool_matrix(true, false, true, false);
let m2 = make_bool_matrix(true, true, false, false);
let expected = make_bool_matrix(true, false, false, false);
assert_eq!(m1.clone() & m2.clone(), expected, "M & M");
assert_eq!(m1.clone() & &m2, expected, "M & &M");
assert_eq!(&m1 & m2.clone(), expected, "&M & M");
assert_eq!(&m1 & &m2, expected, "&M & &M");
}
#[test]
fn test_bitor_bool() {
let m1 = make_bool_matrix(true, false, true, false);
let m2 = make_bool_matrix(true, true, false, false);
let expected = make_bool_matrix(true, true, true, false);
assert_eq!(m1.clone() | m2.clone(), expected, "M | M");
assert_eq!(m1.clone() | &m2, expected, "M | &M");
assert_eq!(&m1 | m2.clone(), expected, "&M | M");
assert_eq!(&m1 | &m2, expected, "&M | &M");
}
#[test]
fn test_bitxor_bool() {
let m1 = make_bool_matrix(true, false, true, false);
let m2 = make_bool_matrix(true, true, false, false);
let expected = make_bool_matrix(false, true, true, false);
assert_eq!(m1.clone() ^ m2.clone(), expected, "M ^ M");
assert_eq!(m1.clone() ^ &m2, expected, "M ^ &M");
assert_eq!(&m1 ^ m2.clone(), expected, "&M ^ M");
assert_eq!(&m1 ^ &m2, expected, "&M ^ &M");
}
#[test]
fn test_not_bool() {
let m = make_bool_matrix(true, false, true, false);
let expected = make_bool_matrix(false, true, false, true);
assert_eq!(!m.clone(), expected, "!M (consuming)");
assert_eq!(!&m, expected, "!&M (borrowing)");
// Check original is unchanged when using !&M
let original = make_bool_matrix(true, false, true, false);
let _negated_ref = !&original;
assert_eq!(original, make_bool_matrix(true, false, true, false));
}
// --- Comparison Tests ---
#[test]
fn test_comparison_eq_elem() {
let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let m2 = make_f64_matrix(1.0, 0.0, 3.0, 5.0);
let s = 3.0;
let expected_m = make_bool_matrix(true, false, true, false);
let expected_s = make_bool_matrix(false, false, true, false);
assert_eq!(m1.eq_elem(m2), expected_m, "eq_elem matrix");
assert_eq!(m1.eq_elem(s), expected_s, "eq_elem scalar");
}
#[test]
fn test_comparison_gt_elem() {
let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let m2 = make_f64_matrix(0.0, 3.0, 3.0, 5.0);
let s = 2.5;
let expected_m = make_bool_matrix(true, false, false, false);
let expected_s = make_bool_matrix(false, false, true, true);
assert_eq!(m1.gt_elem(m2), expected_m, "gt_elem matrix");
assert_eq!(m1.gt_elem(s), expected_s, "gt_elem scalar");
}
// Add more comparison tests (lt, le, ge, ne) if desired...
// --- Basic Method Tests ---
#[test]
fn test_indexing() {
let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
assert_eq!(m[(0, 0)], 1.0);
assert_eq!(m[(0, 1)], 2.0);
assert_eq!(m[(1, 0)], 3.0);
assert_eq!(m[(1, 1)], 4.0);
assert_eq!(*m.get(1, 0), 3.0); // Test get() too
}
#[test]
#[should_panic]
fn test_index_out_of_bounds_row() {
let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let _ = m[(2, 0)]; // Row 2 is out of bounds
}
#[test]
#[should_panic]
fn test_index_out_of_bounds_col() {
let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let _ = m[(0, 2)]; // Col 2 is out of bounds
}
#[test]
fn test_dimensions() {
let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
assert_eq!(m.rows(), 2);
assert_eq!(m.cols(), 2);
}
#[test]
fn test_from_vec() {
let data = vec![1.0, 3.0, 2.0, 4.0]; // Column major: [col0_row0, col0_row1, col1_row0, col1_row1]
let m = Matrix::from_vec(data, 2, 2);
let expected = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
assert_eq!(m, expected);
assert_eq!(m.to_vec(), vec![1.0, 3.0, 2.0, 4.0]);
}
#[test]
fn test_from_rows_vec() {
// Representing:
// 1 2 3
// 4 5 6
let rows_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let matrix = Matrix::from_rows_vec(rows_data, 2, 3);
let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]; // Column-major
let expected = Matrix::from_vec(data, 2, 3);
assert_eq!(matrix, expected);
}
// Helper function to create a basic Matrix for testing
fn static_test_matrix() -> Matrix<i32> {
// Column-major data:
// 1 4 7
// 2 5 8
// 3 6 9
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
Matrix::from_vec(data, 3, 3)
}
// Another helper for a different size
fn static_test_matrix_2x4() -> Matrix<i32> {
// Column-major data:
// 1 3 5 7
// 2 4 6 8
let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
Matrix::from_vec(data, 2, 4)
}
#[test]
fn test_from_vec_basic() {
let data = vec![1, 2, 3, 4, 5, 6]; // 2 rows, 3 cols (column-major)
let matrix = Matrix::from_vec(data, 2, 3);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6]);
// Check some elements
assert_eq!(matrix[(0, 0)], 1); // First row, first col
assert_eq!(matrix[(1, 0)], 2); // Second row, first col
assert_eq!(matrix[(0, 1)], 3); // First row, second col
assert_eq!(matrix[(1, 2)], 6); // Second row, third col
}
#[test]
fn test_transpose() {
let matrix = static_test_matrix();
let transposed = matrix.transpose();
let round_triped = transposed.transpose();
assert_eq!(
round_triped, matrix,
"Transposing twice should return original matrix"
);
for r in 0..matrix.rows() {
for c in 0..matrix.cols() {
assert_eq!(matrix[(r, c)], transposed[(c, r)]);
}
}
}
#[test]
fn test_transpose_big() {
let data: Vec<i32> = (1..=20000).collect(); //
let matrix = Matrix::from_vec(data, 100, 200);
let transposed = matrix.transpose();
assert_eq!(transposed.rows(), 200);
assert_eq!(transposed.cols(), 100);
assert_eq!(transposed.data().len(), 20000);
assert_eq!(transposed[(0, 0)], 1);
let round_trip = transposed.transpose();
assert_eq!(
round_trip, matrix,
"Transposing back should return original matrix"
);
}
#[test]
#[should_panic(expected = "data length mismatch")]
fn test_from_vec_wrong_length() {
let data = vec![1, 2, 3, 4, 5]; // Should be 6 for 2x3
Matrix::from_vec(data, 2, 3);
}
#[test]
#[should_panic(expected = "need at least one row")]
fn test_from_vec_zero_rows() {
let data = vec![1, 2, 3];
Matrix::from_vec(data, 0, 3);
}
#[test]
#[should_panic(expected = "need at least one column")]
fn test_from_vec_zero_cols() {
let data = vec![1, 2, 3];
Matrix::from_vec(data, 3, 0);
}
#[test]
fn test_from_cols_basic() {
// Representing:
// 1 4 7
// 2 5 8
// 3 6 9
let cols_data = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
let matrix = Matrix::from_cols(cols_data);
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3);
// Internal data should be column-major
assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
// Check some elements
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(2, 0)], 3);
assert_eq!(matrix[(1, 1)], 5);
assert_eq!(matrix[(0, 2)], 7);
}
#[test]
fn test_from_cols_1x1() {
let cols_data = vec![vec![42]];
let matrix = Matrix::from_cols(cols_data);
assert_eq!(matrix.rows(), 1);
assert_eq!(matrix.cols(), 1);
assert_eq!(matrix.data(), &[42]);
assert_eq!(matrix[(0, 0)], 42);
}
#[test]
#[should_panic(expected = "need at least one column")]
fn test_from_cols_empty_cols() {
let empty_cols: Vec<Vec<i32>> = vec![];
Matrix::from_cols(empty_cols);
}
#[test]
#[should_panic(expected = "need at least one row")]
fn test_from_cols_empty_rows() {
let empty_row: Vec<Vec<String>> = vec![vec![], vec![]];
Matrix::from_cols(empty_row);
}
#[test]
#[should_panic(expected = "col 1 has len 2, expected 3")]
fn test_from_cols_mismatched_lengths() {
let cols_data = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8]];
Matrix::from_cols(cols_data);
}
#[test]
fn test_getters() {
let matrix = static_test_matrix();
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
}
#[test]
fn test_index_and_get() {
let matrix = static_test_matrix();
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(1, 1)], 5);
assert_eq!(matrix[(2, 2)], 9);
assert_eq!(*matrix.get(0, 0), 1);
assert_eq!(*matrix.get(1, 1), 5);
assert_eq!(*matrix.get(2, 2), 9);
}
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_index_out_of_bounds_row_alt() {
let matrix = static_test_matrix();
let _ = matrix[(3, 0)];
}
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_index_out_of_bounds_col_alt() {
let matrix = static_test_matrix();
let _ = matrix[(0, 3)];
}
#[test]
fn test_index_mut_and_get_mut() {
let mut matrix = static_test_matrix();
matrix[(0, 0)] = 10;
matrix[(1, 1)] = 20;
matrix[(2, 2)] = 30;
assert_eq!(matrix[(0, 0)], 10);
assert_eq!(matrix[(1, 1)], 20);
assert_eq!(matrix[(2, 2)], 30);
*matrix.get_mut(0, 1) = 15;
*matrix.get_mut(2, 1) = 25;
assert_eq!(matrix[(0, 1)], 15);
assert_eq!(matrix[(2, 1)], 25);
assert_eq!(matrix.data(), &[10, 2, 3, 15, 20, 25, 7, 8, 30]);
}
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_index_mut_out_of_bounds_row() {
let mut matrix = static_test_matrix();
matrix[(3, 0)] = 99;
}
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_index_mut_out_of_bounds_col() {
let mut matrix = static_test_matrix();
matrix[(0, 3)] = 99;
}
#[test]
fn test_row() {
let ma = static_test_matrix();
assert_eq!(ma.row(0), &[1, 4, 7]);
assert_eq!(ma.row(1), &[2, 5, 8]);
assert_eq!(ma.row(2), &[3, 6, 9]);
}
#[test]
fn test_row_copy_from_slice() {
let mut ma = static_test_matrix();
let new_row = vec![10, 20, 30];
ma.row_copy_from_slice(1, &new_row);
assert_eq!(ma.row(1), &[10, 20, 30]);
}
#[test]
#[should_panic(expected = "row index 4 out of bounds for 3 rows")]
fn test_row_copy_from_slice_out_of_bounds() {
let mut ma = static_test_matrix();
let new_row = vec![10, 20, 30];
ma.row_copy_from_slice(4, &new_row);
}
#[test]
#[should_panic(expected = "row index 3 out of bounds for 3 rows")]
fn test_row_out_of_bounds_index() {
let ma = static_test_matrix();
ma.row(3);
}
#[test]
#[should_panic(expected = "input slice length 2 does not match number of columns 3")]
fn test_row_copy_from_slice_wrong_length() {
let mut ma = static_test_matrix();
let new_row = vec![10, 20]; // Only 2 elements, but row length is 3
ma.row_copy_from_slice(1, &new_row);
}
#[test]
fn test_shape() {
let ma = static_test_matrix_2x4();
assert_eq!(ma.shape(), (2, 4));
assert_eq!(ma.rows(), 2);
assert_eq!(ma.cols(), 4);
}
#[test]
fn test_repeat_rows() {
let ma = static_test_matrix();
// Returns a new matrix where row 0 of `self` is repeated `n` times.
let repeated = ma.repeat_rows(3);
// assert all rows are equal to the first row
for r in 0..repeated.rows() {
assert_eq!(repeated.row(r), ma.row(0));
}
}
#[test]
#[should_panic(expected = "row index 3 out of bounds for 3 rows")]
fn test_row_out_of_bounds() {
let ma = static_test_matrix();
ma.row(3);
}
#[test]
fn test_column() {
let matrix = static_test_matrix_2x4();
assert_eq!(matrix.column(0), &[1, 2]);
assert_eq!(matrix.column(1), &[3, 4]);
assert_eq!(matrix.column(2), &[5, 6]);
assert_eq!(matrix.column(3), &[7, 8]);
}
#[test]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_column_out_of_bounds() {
let matrix = static_test_matrix_2x4();
matrix.column(4);
}
#[test]
fn test_column_mut() {
let mut matrix = static_test_matrix_2x4();
let col1_mut = matrix.column_mut(1);
col1_mut[0] = 30;
col1_mut[1] = 40;
let col3_mut = matrix.column_mut(3);
col3_mut[0] = 70;
// Check changes via indexing
assert_eq!(matrix[(0, 1)], 30);
assert_eq!(matrix[(1, 1)], 40);
assert_eq!(matrix[(0, 3)], 70);
assert_eq!(matrix[(1, 3)], 8); // Unchanged
assert_eq!(matrix.data(), &[1, 2, 30, 40, 5, 6, 70, 8]);
}
#[test]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_column_mut_out_of_bounds() {
let mut matrix = static_test_matrix_2x4();
matrix.column_mut(4);
}
#[test]
fn test_iter_columns() {
let matrix = static_test_matrix_2x4();
let cols: Vec<&[i32]> = matrix.iter_columns().collect();
assert_eq!(cols.len(), 4);
assert_eq!(cols[0], &[1, 2]);
assert_eq!(cols[1], &[3, 4]);
assert_eq!(cols[2], &[5, 6]);
assert_eq!(cols[3], &[7, 8]);
}
#[test]
fn test_iter_rows() {
let matrix = static_test_matrix_2x4();
let rows: Vec<Vec<i32>> = matrix
.iter_rows()
.map(|row| row.iter().cloned().collect())
.collect();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0], vec![1, 3, 5, 7]);
assert_eq!(rows[1], vec![2, 4, 6, 8]);
}
// test data_mut
#[test]
fn test_data_mut() {
let mut matrix = static_test_matrix();
let data_mut = matrix.data_mut();
data_mut[0] = 10;
data_mut[1] = 20;
assert_eq!(matrix[(0, 0)], 10);
assert_eq!(matrix[(1, 0)], 20);
}
#[test]
fn test_matrix_row_get_and_iter() {
let matrix = static_test_matrix_2x4();
let row0 = matrix.iter_rows().next().unwrap();
assert_eq!(*row0.get(0), 1);
assert_eq!(*row0.get(1), 3);
assert_eq!(*row0.get(3), 7);
let row0_vec: Vec<i32> = row0.iter().cloned().collect();
assert_eq!(row0_vec, vec![1, 3, 5, 7]);
let row1 = matrix.iter_rows().nth(1).unwrap();
assert_eq!(*row1.get(0), 2);
assert_eq!(*row1.get(2), 6);
let row1_vec: Vec<i32> = row1.iter().cloned().collect();
assert_eq!(row1_vec, vec![2, 4, 6, 8]);
}
#[test]
fn test_swap_columns() {
let mut matrix = static_test_matrix();
matrix.swap_columns(0, 2); // swap first and last
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], 7);
assert_eq!(matrix[(1, 0)], 8);
assert_eq!(matrix[(2, 0)], 9);
assert_eq!(matrix[(0, 1)], 4); // Middle col unchanged
assert_eq!(matrix[(1, 1)], 5);
assert_eq!(matrix[(2, 1)], 6);
assert_eq!(matrix[(0, 2)], 1);
assert_eq!(matrix[(1, 2)], 2);
assert_eq!(matrix[(2, 2)], 3);
// swap the same column (should do nothing)
let original_data = matrix.data().to_vec();
matrix.swap_columns(1, 1);
assert_eq!(matrix.data(), &original_data); // Data should be identical
// Check underlying data (column-major) after swap(0, 2)
assert_eq!(matrix.data(), &[7, 8, 9, 4, 5, 6, 1, 2, 3]);
}
#[test]
#[should_panic(expected = "column index c2=3 out of bounds for 3 columns")]
fn test_swap_columns_out_of_bounds() {
let mut matrix = static_test_matrix();
matrix.swap_columns(0, 3);
}
#[test]
fn test_delete_column() {
let mut matrix = static_test_matrix_2x4();
matrix.delete_column(1); // Delete the second column (index 1)
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(1, 0)], 2);
assert_eq!(matrix[(0, 1)], 5);
assert_eq!(matrix[(1, 1)], 6);
assert_eq!(matrix[(0, 2)], 7);
assert_eq!(matrix[(1, 2)], 8);
// check underlying data
assert_eq!(matrix.data(), &[1, 2, 5, 6, 7, 8]);
// Delete the first column
matrix.delete_column(0);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 2);
assert_eq!(matrix.data(), &[5, 6, 7, 8]);
// Delete the last column
matrix.delete_column(1);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 1);
assert_eq!(matrix.data(), &[5, 6]);
// Delete the only column
matrix.delete_column(0);
// Should be empty
assert_eq!(matrix.rows(), 2); // Rows stay the same
assert_eq!(matrix.cols(), 0); // Cols becomes 0
assert_eq!(matrix.data(), &[]);
}
#[test]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_delete_column_out_of_bounds() {
let mut matrix = static_test_matrix_2x4();
matrix.delete_column(4);
}
#[test]
fn test_delete_row() {
let mut matrix = static_test_matrix();
matrix.delete_row(1); // Delete the second row
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(1, 0)], 3);
assert_eq!(matrix[(0, 1)], 4);
assert_eq!(matrix[(1, 1)], 6);
assert_eq!(matrix[(0, 2)], 7);
assert_eq!(matrix[(1, 2)], 9);
// check underlying data (column-major)
assert_eq!(matrix.data(), &[1, 3, 4, 6, 7, 9]);
// Delete the first row
matrix.delete_row(0);
// Should be:
// 3 6 9
assert_eq!(matrix.rows(), 1);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix.data(), &[3, 6, 9]);
// Delete the last (and only) row
matrix.delete_row(0);
// Should be empty
assert_eq!(matrix.rows(), 0); // Rows becomes 0
assert_eq!(matrix.cols(), 3); // Cols stay the same
assert_eq!(matrix.data(), &[]);
}
#[test]
#[should_panic(expected = "row index 3 out of bounds for 3 rows")]
fn test_delete_row_out_of_bounds() {
let mut matrix = static_test_matrix();
matrix.delete_row(3);
}
#[test]
fn test_add_column() {
let mut matrix = static_test_matrix_2x4();
let new_col = vec![9, 10];
matrix.add_column(2, new_col);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 5);
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(1, 0)], 2);
assert_eq!(matrix[(0, 1)], 3);
assert_eq!(matrix[(1, 1)], 4);
assert_eq!(matrix[(0, 2)], 9);
assert_eq!(matrix[(1, 2)], 10);
assert_eq!(matrix[(0, 3)], 5); // Shifted
assert_eq!(matrix[(1, 3)], 6);
assert_eq!(matrix[(0, 4)], 7); // Shifted
assert_eq!(matrix[(1, 4)], 8);
// Check underlying data
assert_eq!(matrix.data(), &[1, 2, 3, 4, 9, 10, 5, 6, 7, 8]);
// Add a column at the beginning
let new_col_start = vec![11, 12];
matrix.add_column(0, new_col_start);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 6);
assert_eq!(matrix[(0, 0)], 11);
assert_eq!(matrix[(1, 0)], 12);
assert_eq!(matrix.data(), &[11, 12, 1, 2, 3, 4, 9, 10, 5, 6, 7, 8]);
// Add a column at the end
let new_col_end = vec![13, 14];
matrix.add_column(6, new_col_end);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 7);
assert_eq!(matrix[(0, 6)], 13);
assert_eq!(matrix[(1, 6)], 14);
assert_eq!(
matrix.data(),
&[11, 12, 1, 2, 3, 4, 9, 10, 5, 6, 7, 8, 13, 14]
);
}
#[test]
#[should_panic(expected = "add_column index 5 out of bounds for 4 columns")]
fn test_add_column_out_of_bounds() {
let mut matrix = static_test_matrix_2x4();
let new_col = vec![9, 10];
matrix.add_column(5, new_col); // Index 5 is out of bounds for 4 columns
}
#[test]
#[should_panic(expected = "column length mismatch")]
fn test_add_column_length_mismatch() {
let mut matrix = static_test_matrix_2x4();
let new_col = vec![9, 10, 11]; // Wrong length
matrix.add_column(0, new_col);
}
#[test]
fn test_add_row() {
let mut matrix = static_test_matrix_2x4();
let new_row = vec![9, 10, 11, 12];
matrix.add_row(1, new_row);
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 4);
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(0, 1)], 3);
assert_eq!(matrix[(0, 2)], 5);
assert_eq!(matrix[(0, 3)], 7);
assert_eq!(matrix[(1, 0)], 9);
assert_eq!(matrix[(1, 1)], 10);
assert_eq!(matrix[(1, 2)], 11);
assert_eq!(matrix[(1, 3)], 12);
assert_eq!(matrix[(2, 0)], 2);
assert_eq!(matrix[(2, 1)], 4);
assert_eq!(matrix[(2, 2)], 6);
assert_eq!(matrix[(2, 3)], 8);
// Check underlying data (column-major)
assert_eq!(matrix.data(), &[1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]);
// Add a row at the beginning
let new_row_start = vec![13, 14, 15, 16];
matrix.add_row(0, new_row_start);
assert_eq!(matrix.rows(), 4);
assert_eq!(matrix.cols(), 4);
assert_eq!(matrix[(0, 0)], 13);
assert_eq!(matrix[(0, 1)], 14);
assert_eq!(matrix[(0, 2)], 15);
assert_eq!(matrix[(0, 3)], 16);
assert_eq!(matrix[(1, 0)], 1);
assert_eq!(matrix[(2, 1)], 10);
assert_eq!(matrix[(3, 3)], 8);
// Add a row at the end
let new_row_end = vec![17, 18, 19, 20];
matrix.add_row(4, new_row_end);
assert_eq!(matrix.rows(), 5);
assert_eq!(matrix.cols(), 4);
assert_eq!(matrix[(4, 0)], 17);
assert_eq!(matrix[(4, 3)], 20);
}
#[test]
#[should_panic(expected = "add_row index 3 out of bounds for 2 rows")]
fn test_add_row_out_of_bounds() {
let mut matrix = static_test_matrix_2x4();
let new_row = vec![9, 10, 11, 12];
matrix.add_row(3, new_row); // Index 3 is out of bounds for 2 rows
}
#[test]
#[should_panic(expected = "row length mismatch")]
fn test_add_row_length_mismatch() {
let mut matrix = static_test_matrix_2x4();
let new_row = vec![9, 10, 11]; // Wrong length
matrix.add_row(0, new_row);
}
#[test]
fn test_elementwise_add() {
let matrix1 = static_test_matrix();
let matrix2 = Matrix::from_vec(vec![9, 8, 7, 6, 5, 4, 3, 2, 1], 3, 3);
let result = &matrix1 + &matrix2;
assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3);
// Expect all 10s
assert_eq!(result.data(), &[10, 10, 10, 10, 10, 10, 10, 10, 10]);
assert_eq!(result[(0, 0)], 10);
assert_eq!(result[(1, 1)], 10);
assert_eq!(result[(2, 2)], 10);
}
#[test]
fn test_elementwise_sub() {
let matrix1 = static_test_matrix();
let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], 3, 3);
let result = &matrix1 - &matrix2;
assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3);
assert_eq!(result.data(), &[0, 1, 2, 2, 3, 4, 4, 5, 6]);
assert_eq!(result[(0, 0)], 0);
assert_eq!(result[(1, 1)], 3);
assert_eq!(result[(2, 2)], 6);
}
#[test]
fn test_elementwise_mul() {
let matrix1 = static_test_matrix();
let matrix2 = Matrix::from_vec(vec![1, 2, 3, 1, 2, 3, 1, 2, 3], 3, 3);
let result = &matrix1 * &matrix2;
assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3);
// Expected
assert_eq!(result.data(), &[1, 4, 9, 4, 10, 18, 7, 16, 27]);
assert_eq!(result[(0, 0)], 1);
assert_eq!(result[(1, 1)], 10);
assert_eq!(result[(2, 2)], 27);
}
#[test]
fn test_elementwise_div() {
let matrix1 = static_test_matrix();
let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 7, 8, 9], 3, 3);
let result = &matrix1 / &matrix2; // Integer division
assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3);
assert_eq!(result.data(), &[1, 2, 3, 2, 2, 3, 1, 1, 1]);
assert_eq!(result[(0, 0)], 1);
assert_eq!(result[(1, 1)], 2);
assert_eq!(result[(2, 2)], 1);
}
#[test]
#[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 rows")]
fn test_elementwise_op_row_mismatch() {
let matrix1 = static_test_matrix();
let matrix2 = static_test_matrix_2x4();
let _ = &matrix1 + &matrix2; // Should panic
}
#[test]
#[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 ro")]
fn test_elementwise_op_col_mismatch() {
let matrix1 = static_test_matrix();
let matrix2 = static_test_matrix_2x4();
let _ = &matrix1 * &matrix2; // Should panic
}
#[test]
fn test_bitwise_and() {
let data1 = vec![true, false, true, false, true, false];
let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
let expected_data = vec![true, false, false, false, true, false];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
let result = &matrix1 & &matrix2;
assert_eq!(result, expected_matrix);
}
#[test]
fn test_bitwise_or() {
let data1 = vec![true, false, true, false, true, false];
let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
let expected_data = vec![true, true, true, false, true, true];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
let result = &matrix1 | &matrix2;
assert_eq!(result, expected_matrix);
}
#[test]
fn test_bitwise_xor() {
let data1 = vec![true, false, true, false, true, false];
let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
let expected_data = vec![false, true, true, false, false, true];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
let result = &matrix1 ^ &matrix2;
assert_eq!(result, expected_matrix);
}
#[test]
fn test_bitwise_not() {
let data = vec![true, false, true, false, true, false];
let matrix = BoolMatrix::from_vec(data, 2, 3);
let expected_data = vec![false, true, false, true, false, true];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
let result = !matrix; // Not consumes the matrix
assert_eq!(result, expected_matrix);
}
#[test]
#[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")]
fn test_bitwise_op_row_mismatch() {
let data1 = vec![true, false, true, false];
let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 2);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
let _ = &matrix1 & &matrix2; // Should panic
}
#[test]
#[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")]
fn test_bitwise_op_col_mismatch() {
let data1 = vec![true, false, true, false];
let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 2);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
let _ = &matrix1 | &matrix2; // Should panic
}
// Test with String type (requires Clone, PartialEq)
#[test]
fn test_string_matrix() {
let data = vec![
"a".to_string(),
"b".to_string(),
"c".to_string(),
"d".to_string(),
];
let matrix = StringMatrix::from_vec(data.clone(), 2, 2);
assert_eq!(matrix[(0, 0)], "a".to_string());
assert_eq!(matrix[(1, 0)], "b".to_string());
assert_eq!(matrix[(0, 1)], "c".to_string());
assert_eq!(matrix[(1, 1)], "d".to_string());
// Test modification
let mut matrix = matrix;
matrix[(0, 0)] = "hello".to_string();
assert_eq!(matrix[(0, 0)], "hello".to_string());
// Test add_column
let new_col = vec!["e".to_string(), "f".to_string()];
matrix.add_column(1, new_col);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], "hello".to_string());
assert_eq!(matrix[(1, 0)], "b".to_string());
assert_eq!(matrix[(0, 1)], "e".to_string()); // New col
assert_eq!(matrix[(1, 1)], "f".to_string()); // New col
assert_eq!(matrix[(0, 2)], "c".to_string()); // Shifted
assert_eq!(matrix[(1, 2)], "d".to_string()); // Shifted
// Test add_row (requires Clone)
let new_row = vec!["g".to_string(), "h".to_string(), "i".to_string()];
matrix.add_row(0, new_row);
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], "g".to_string());
assert_eq!(matrix[(0, 1)], "h".to_string());
assert_eq!(matrix[(0, 2)], "i".to_string());
assert_eq!(matrix[(1, 0)], "hello".to_string()); // Shifted
assert_eq!(matrix[(2, 2)], "d".to_string()); // Shifted
}
#[test]
fn test_float_matrix_ops() {
let data1 = vec![1.0, 2.0, 3.0, 4.0];
let data2 = vec![0.5, 1.5, 2.5, 3.5];
let matrix1 = FloatMatrix::from_vec(data1, 2, 2);
let matrix2 = FloatMatrix::from_vec(data2, 2, 2);
let sum = &matrix1 + &matrix2;
let diff = &matrix1 - &matrix2;
let prod = &matrix1 * &matrix2;
let div = &matrix1 / &matrix2;
assert_eq!(sum.data(), &[1.5, 3.5, 5.5, 7.5]);
assert_eq!(diff.data(), &[0.5, 0.5, 0.5, 0.5]);
assert_eq!(prod.data(), &[0.5, 3.0, 7.5, 14.0]);
// Check divs (col-major): [2.0, 1.333..., 1.2, 1.14...]
assert_eq!(div.rows(), 2);
assert_eq!(div.cols(), 2);
assert!((div[(0, 0)] - 1.0 / 0.5).abs() < 1e-9); // 2.0
assert!((div[(1, 0)] - 2.0 / 1.5).abs() < 1e-9); // 1.333...
assert!((div[(0, 1)] - 3.0 / 2.5).abs() < 1e-9); // 1.2
assert!((div[(1, 1)] - 4.0 / 3.5).abs() < 1e-9); // 1.14...
}
fn create_test_matrix_i32() -> Matrix<i32> {
Matrix::from_cols(vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]])
}
#[test]
fn test_matrix_swap_columns_directly() {
let mut matrix = create_test_matrix_i32();
// Store the initial state of the columns we intend to swap AND one that shouldn't change
let initial_col0_data = matrix.column(0).to_vec(); // Should be [1, 2, 3]
let initial_col1_data = matrix.column(1).to_vec(); // Should be [4, 5, 6]
let initial_col2_data = matrix.column(2).to_vec(); // Should be [7, 8, 9]
// Perform the swap directly on the matrix
matrix.swap_columns(0, 2); // Swap column 0 and column 2
// Verify the dimensions are unchanged
assert_eq!(matrix.rows(), 3, "Matrix rows should remain unchanged");
assert_eq!(matrix.cols(), 3, "Matrix cols should remain unchanged");
// Verify the column that was NOT swapped is unchanged
assert_eq!(
matrix.column(1),
initial_col1_data.as_slice(), // Comparing slice to slice
"Column 1 data should be unchanged"
);
// Verify the data swap occurred correctly using the COLUMN ACCESSOR
// The data originally at index 0 should now be at index 2
assert_eq!(
matrix.column(2),
initial_col0_data.as_slice(),
"Column 2 should now contain the original data from column 0"
);
// The data originally at index 2 should now be at index 0
assert_eq!(
matrix.column(0),
initial_col2_data.as_slice(),
"Column 0 should now contain the original data from column 2"
);
// Verify the underlying raw data vector
assert_eq!(
matrix.data(),
&[7, 8, 9, 4, 5, 6, 1, 2, 3],
"Underlying data vector is incorrect after swap"
);
// Test swapping with self (should be a no-op)
let state_before_self_swap = matrix.clone();
matrix.swap_columns(1, 1);
assert_eq!(
matrix, state_before_self_swap,
"Swapping a column with itself should not change the matrix"
);
// Test swapping adjacent columns
let mut matrix2 = create_test_matrix_i32();
let initial_col0_data_m2 = matrix2.column(0).to_vec();
let initial_col1_data_m2 = matrix2.column(1).to_vec();
matrix2.swap_columns(0, 1);
assert_eq!(matrix2.column(0), initial_col1_data_m2.as_slice());
assert_eq!(matrix2.column(1), initial_col0_data_m2.as_slice());
assert_eq!(matrix2.data(), &[4, 5, 6, 1, 2, 3, 7, 8, 9]);
}
// Test broadcastable operations
#[test]
fn test_comparision_broadcast() {
let matrix = static_test_matrix();
// test all > 0
let result = matrix.gt_elem(0).into_vec();
let expected = vec![true; result.len()];
assert_eq!(result, expected);
let ma = static_test_matrix();
let mb = static_test_matrix();
let result = ma.eq_elem(mb);
assert!(result.all());
let result = matrix.lt_elem(1e10 as i32).all();
assert!(result);
for i in 0..matrix.rows() {
for j in 0..matrix.cols() {
let vx = matrix[(i, j)];
let c = &(matrix.le_elem(vx)) & &(matrix.ge_elem(vx));
assert_eq!(c.count(), 1);
}
}
}
#[test]
fn test_arithmetic_broadcast() {
let matrix = static_test_matrix();
let result = &matrix + 1;
for i in 0..matrix.rows() {
for j in 0..matrix.cols() {
assert_eq!(result[(i, j)], matrix[(i, j)] + 1);
}
}
// test mul and div
let result = &matrix * 2;
let result2 = &matrix / 2;
for i in 0..matrix.rows() {
for j in 0..matrix.cols() {
assert_eq!(result[(i, j)], matrix[(i, j)] * 2);
assert_eq!(result2[(i, j)], matrix[(i, j)] / 2);
}
}
}
#[test]
fn test_matrix_zeros_ones_filled() {
// Test zeros
let m = Matrix::<f64>::zeros(2, 3);
assert_eq!(m.rows(), 2);
assert_eq!(m.cols(), 3);
assert_eq!(m.data(), &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
// Test ones
let m = Matrix::<f64>::ones(3, 2);
assert_eq!(m.rows(), 3);
assert_eq!(m.cols(), 2);
assert_eq!(m.data(), &[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
// Test filled
let m = Matrix::<f64>::filled(2, 2, 42.5);
assert_eq!(m.rows(), 2);
assert_eq!(m.cols(), 2);
assert_eq!(m.data(), &[42.5, 42.5, 42.5, 42.5]);
// test with an integer matrix
let m = Matrix::<i32>::filled(2, 3, 7);
assert_eq!(m.rows(), 2);
assert_eq!(m.cols(), 3);
assert_eq!(m.data(), &[7, 7, 7, 7, 7, 7]);
// test with nans
let m = Matrix::nan(3, 3);
assert_eq!(m.rows(), 3);
assert_eq!(m.cols(), 3);
for &value in m.data() {
assert!(value.is_nan(), "Expected NaN, got {}", value);
}
}
#[test]
fn test_broadcast_row_to_target_shape_basic() {
let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
let target_rows = 5;
let target_cols = 3;
let broadcasted = single_row_matrix.broadcast_row_to_target_shape(target_rows, target_cols);
assert_eq!(broadcasted.rows(), target_rows);
assert_eq!(broadcasted.cols(), target_cols);
for r in 0..target_rows {
assert_eq!(broadcasted.row(r), vec![1.0, 2.0, 3.0]);
}
}
#[test]
fn test_broadcast_row_to_target_shape_single_row() {
let single_row_matrix = Matrix::from_rows_vec(vec![10.0, 20.0], 1, 2);
let target_rows = 1;
let target_cols = 2;
let broadcasted = single_row_matrix.broadcast_row_to_target_shape(target_rows, target_cols);
assert_eq!(broadcasted.rows(), target_rows);
assert_eq!(broadcasted.cols(), target_cols);
assert_eq!(broadcasted.row(0), vec![10.0, 20.0]);
}
#[test]
#[should_panic(
expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix."
)]
fn test_broadcast_row_to_target_shape_panic_not_1_row() {
let multi_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
multi_row_matrix.broadcast_row_to_target_shape(3, 2);
}
#[test]
#[should_panic(
expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns."
)]
fn test_broadcast_row_to_target_shape_panic_col_mismatch() {
let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
single_row_matrix.broadcast_row_to_target_shape(5, 4);
}
}