//! A simple column-major Matrix implementation with element-wise operations.

use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Index, IndexMut, Mul, Not, Sub};

/// A column-major 2D matrix of `T`.
///
/// Elements are stored column by column in one flat `Vec<T>`: element
/// `(row, col)` lives at offset `col * rows + row`. Index with a tuple,
/// e.g. `matrix[(row, column)]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Matrix<T> {
    rows: usize,
    cols: usize,
    data: Vec<T>,
}

impl<T> Matrix<T> {
    /// Builds a matrix from a list of columns (each inner `Vec` is one column).
    ///
    /// # Panics
    /// Panics if `cols_data` is empty, if the first column is empty, or if
    /// any column's length differs from the first column's length.
    pub fn from_cols(cols_data: Vec<Vec<T>>) -> Self {
        let cols = cols_data.len();
        assert!(cols > 0, "need at least one column");
        // `cols > 0` was just asserted, so the first column exists.
        let rows = cols_data[0].len();
        assert!(rows > 0, "need at least one row if columns exist");
        for (i, col) in cols_data.iter().enumerate() {
            assert!(
                col.len() == rows,
                "col {} has len {}, expected {}",
                i,
                col.len(),
                rows
            );
        }
        // Concatenating the columns in order yields column-major storage.
        let data = cols_data.into_iter().flatten().collect();
        Matrix { rows, cols, data }
    }

    /// Builds a matrix from a flat `Vec` that is already in column-major order.
    ///
    /// # Panics
    /// Panics if exactly one of `rows` / `cols` is zero, if `rows * cols`
    /// overflows `usize`, or if `data.len() != rows * cols`.
    pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> Self {
        assert!(
            rows > 0 || cols == 0,
            "need at least one row if columns exist"
        );
        assert!(
            cols > 0 || rows == 0,
            "need at least one column if rows exist"
        );
        // A wrapped product could accidentally match `data.len()` in release
        // builds, so the multiplication is checked explicitly.
        let expected = rows
            .checked_mul(cols)
            .expect("matrix dimensions overflow usize");
        if expected != 0 {
            assert_eq!(
                data.len(),
                expected,
                "data length mismatch: expected {}, got {}",
                expected,
                data.len()
            );
        } else {
            assert!(data.is_empty(), "data must be empty for 0-sized matrix");
        }
        Matrix { rows, cols, data }
    }

    /// Returns the underlying column-major storage as a slice.
    pub fn data(&self) -> &[T] {
        &self.data
    }

    /// Returns the underlying column-major storage as a mutable slice.
    pub fn data_mut(&mut self) -> &mut [T] {
        &mut self.data
    }

    /// Consumes the matrix and returns its underlying column-major `Vec`.
    pub fn into_vec(self) -> Vec<T> {
        self.data
    }

    /// Number of rows.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Returns a reference to the element at row `r`, column `c`.
    ///
    /// # Panics
    /// Panics on out-of-bounds indices.
    pub fn get(&self, r: usize, c: usize) -> &T {
        &self[(r, c)]
    }

    /// Returns a mutable reference to the element at row `r`, column `c`.
    ///
    /// # Panics
    /// Panics on out-of-bounds indices.
    pub fn get_mut(&mut self, r: usize, c: usize) -> &mut T {
        &mut self[(r, c)]
    }

    /// Returns column `c` as a contiguous slice.
    ///
    /// # Panics
    /// Panics if `c >= self.cols()`.
    #[inline]
    pub fn column(&self, c: usize) -> &[T] {
        assert!(
            c < self.cols,
            "column index {} out of bounds for {} columns",
            c,
            self.cols
        );
        let start = c * self.rows;
        &self.data[start..start + self.rows]
    }

    /// Returns column `c` as a contiguous mutable slice.
    ///
    /// # Panics
    /// Panics if `c >= self.cols()`.
    #[inline]
    pub fn column_mut(&mut self, c: usize) -> &mut [T] {
        assert!(
            c < self.cols,
            "column index {} out of bounds for {} columns",
            c,
            self.cols
        );
        let start = c * self.rows;
        &mut self.data[start..start + self.rows]
    }

    /// Iterates over the columns, left to right, as slices.
    pub fn iter_columns(&self) -> impl Iterator<Item = &[T]> {
        (0..self.cols).map(move |c| self.column(c))
    }

    /// Iterates over the rows, top to bottom, as [`MatrixRow`] views.
    pub fn iter_rows(&self) -> impl Iterator<Item = MatrixRow<'_, T>> {
        (0..self.rows).map(move |r| MatrixRow {
            matrix: self,
            row: r,
        })
    }

    /// Swaps columns `c1` and `c2` in place.
    ///
    /// # Panics
    /// Panics if either index is out of bounds.
    pub fn swap_columns(&mut self, c1: usize, c2: usize) {
        assert!(
            c1 < self.cols,
            "column index c1={} out of bounds for {} columns",
            c1,
            self.cols
        );
        assert!(
            c2 < self.cols,
            "column index c2={} out of bounds for {} columns",
            c2,
            self.cols
        );
        if c1 == c2 || self.rows == 0 {
            return;
        }
        let rows = self.rows;
        let (lo, hi) = if c1 < c2 { (c1, c2) } else { (c2, c1) };
        // Distinct columns occupy disjoint ranges, so splitting at the start
        // of the higher one gives two independent mutable slices.
        let (head, tail) = self.data.split_at_mut(hi * rows);
        head[lo * rows..(lo + 1) * rows].swap_with_slice(&mut tail[..rows]);
    }

    /// Deletes column `col`; later columns shift left.
    ///
    /// # Panics
    /// Panics if `col >= self.cols()`.
    pub fn delete_column(&mut self, col: usize) {
        assert!(
            col < self.cols,
            "column index {} out of bounds for {} columns",
            col,
            self.cols
        );
        let start = col * self.rows;
        // A column is one contiguous run in column-major storage.
        self.data.drain(start..start + self.rows);
        self.cols -= 1;
    }

    /// Deletes row `row` from every column, in place and without cloning.
    ///
    /// # Panics
    /// Panics if `row >= self.rows()`.
    pub fn delete_row(&mut self, row: usize) {
        assert!(
            row < self.rows,
            "row index {} out of bounds for {} rows",
            row,
            self.rows
        );
        // `row < self.rows` guarantees a non-zero stride.
        let stride = self.rows;
        let mut offset = 0usize;
        // `Vec::retain` visits elements once, in order, so `offset % stride`
        // is the row index of the element currently being examined.
        self.data.retain(|_| {
            let keep = offset % stride != row;
            offset += 1;
            keep
        });
        self.rows -= 1;
    }

    /// Inserts `column` so that it becomes column `index`.
    ///
    /// # Panics
    /// Panics if `index > self.cols()` or `column.len() != self.rows()`.
    pub fn add_column(&mut self, index: usize, column: Vec<T>) {
        assert!(
            index <= self.cols,
            "add_column index {} out of bounds for {} columns",
            index,
            self.cols
        );
        assert_eq!(
            column.len(),
            self.rows,
            "column length mismatch: expected {}, got {}",
            self.rows,
            column.len()
        );
        // Also covers the empty 0x0 matrix: the column must then be empty and
        // the result is a matrix with one column and zero rows.
        let insert_pos = index * self.rows;
        self.data.splice(insert_pos..insert_pos, column);
        self.cols += 1;
    }

    /// Inserts `row` so that it becomes row `index`; `row[c]` goes to column `c`.
    ///
    /// Adding an (necessarily empty) row to a 0x0 matrix is a no-op.
    ///
    /// # Panics
    /// Panics if `index > self.rows()` or `row.len() != self.cols()`.
    pub fn add_row(&mut self, index: usize, row: Vec<T>) {
        assert!(
            index <= self.rows,
            "add_row index {} out of bounds for {} rows",
            index,
            self.rows
        );
        assert_eq!(
            row.len(),
            self.cols,
            "row length mismatch: expected {} (cols), got {}",
            self.cols,
            row.len()
        );
        if self.rows == 0 && self.cols == 0 {
            // The asserts above force `index == 0` and an empty `row` here.
            return;
        }

        let old_rows = self.rows;
        let mut merged = Vec::with_capacity((old_rows + 1) * self.cols);
        // Move the old elements out instead of cloning them.
        let mut old = std::mem::take(&mut self.data).into_iter();
        for value in row {
            // One iteration per column: rows above the insertion point, the
            // new element, then the rows below it.
            merged.extend(old.by_ref().take(index));
            merged.push(value);
            merged.extend(old.by_ref().take(old_rows - index));
        }
        self.data = merged;
        self.rows = old_rows + 1;
    }
}

impl<T: Clone> Matrix<T> {
    /// Returns a cloned copy of the column-major storage.
    pub fn to_vec(&self) -> Vec<T> {
        self.data.clone()
    }
}

impl<T> Index<(usize, usize)> for Matrix<T> {
    type Output = T;

    /// Indexes by `(row, column)`; panics if either index is out of bounds.
    #[inline]
    fn index(&self, (r, c): (usize, usize)) -> &T {
        assert!(
            r < self.rows && c < self.cols,
            "index out of bounds: ({}, {}) vs {}x{}",
            r,
            c,
            self.rows,
            self.cols
        );
        // Column-major offset.
        &self.data[c * self.rows + r]
    }
}

impl<T> IndexMut<(usize, usize)> for Matrix<T> {
    /// Mutably indexes by `(row, column)`; panics if either index is out of bounds.
    #[inline]
    fn index_mut(&mut self, (r, c): (usize, usize)) -> &mut T {
        assert!(
            r < self.rows && c < self.cols,
            "index out of bounds: ({}, {}) vs {}x{}",
            r,
            c,
            self.rows,
            self.cols
        );
        // Column-major offset.
        &mut self.data[c * self.rows + r]
    }
}

// --- Row Iterator Helper ---

/// Represents an immutable view of a single row in the matrix.
pub struct MatrixRow<'a, T> {
    matrix: &'a Matrix<T>,
    row: usize,
}

impl<'a, T> MatrixRow<'a, T> {
    /// Returns a reference to the element at column `c` of this row.
    ///
    /// The reference borrows from the matrix, not from this view, so it may
    /// outlive the `MatrixRow`. Panics if `c` is out of bounds.
    pub fn get(&self, c: usize) -> &'a T {
        let matrix: &'a Matrix<T> = self.matrix;
        &matrix[(self.row, c)]
    }

    /// Returns an iterator over all elements in this row, left to right.
    pub fn iter(&self) -> impl Iterator<Item = &'a T> {
        // Copy the fields out so the iterator does not borrow `self`.
        let matrix: &'a Matrix<T> = self.matrix;
        let row = self.row;
        (0..matrix.cols).map(move |c| &matrix[(row, c)])
    }
}

// --- Reduction Axis Enum ---

/// Specifies the axis along which to perform a reduction operation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Axis {
    /// Apply reduction along columns (vertical axis). Result has 1 row.
    Col,
    /// Apply reduction along rows (horizontal axis). Result has 1 column.
    Row,
}

// --- Broadcasting ---

/// Turns either a `Matrix<T>` or a scalar `T` into a `Vec<T>` of length
/// `rows * cols` (broadcasting the scalar). Used for comparisons.
pub trait Broadcastable<T> {
    /// Produces `rows * cols` elements in column-major order.
    fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>;
}

impl<T: Clone> Broadcastable<T> for T {
    /// Repeats the scalar `rows * cols` times.
    fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
        vec![self.clone(); rows * cols]
    }
}

impl<T: Clone> Broadcastable<T> for Matrix<T> {
    /// Clones the matrix data; panics if the dimensions do not match exactly.
    fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
        assert_eq!(self.rows, rows, "row count mismatch in broadcast");
        assert_eq!(self.cols, cols, "col count mismatch in broadcast");
        self.data.clone()
    }
}

// --- Element-wise Comparisons ---

/// Generates element-wise eq, ne, lt, le, gt and ge methods
/// where the rhs can be a `Matrix<T>` or a scalar `T`.
macro_rules! impl_elementwise_cmp {
    ( $( $method:ident => $op:tt ),* $(,)? ) => {
        impl<T: PartialOrd + Clone> Matrix<T> {
            $(
                #[doc = concat!(
                    "Element-wise comparison `self ", stringify!($op), " rhs`,\n",
                    "where `rhs` may be a `Matrix` or a scalar T.\n",
                    "Returns a `BoolMatrix`."
                )]
                pub fn $method<Rhs>(&self, rhs: Rhs) -> Matrix<bool>
                where
                    Rhs: Broadcastable<T>,
                {
                    let rhs_data = rhs.to_vec(self.rows, self.cols);
                    let data = self
                        .data
                        .iter()
                        .zip(rhs_data.iter())
                        .map(|(a, b)| a $op b)
                        .collect();
                    // Built directly rather than through `from_vec`: the shape
                    // is copied from `self`, and `from_vec` would reject the
                    // `rows > 0, cols == 0` shape `delete_column` can produce.
                    Matrix { rows: self.rows, cols: self.cols, data }
                }
            )*
        }
    };
}

// Instantiate element-wise comparison implementations for matrices.
impl_elementwise_cmp! {
    eq_elem => ==,
    ne_elem => !=,
    lt_elem => <,
    le_elem => <=,
    gt_elem => >,
    ge_elem => >=,
}

// --- Element-wise Arithmetic Operations (Macros generating all ownership variants) ---

/// Panics unless `lhs` and `rhs` have identical dimensions.
fn check_matrix_dims_for_ops<T>(lhs: &Matrix<T>, rhs: &Matrix<T>) {
    assert!(
        lhs.rows == rhs.rows,
        "Row count mismatch: left has {} rows, right has {} rows",
        lhs.rows,
        rhs.rows
    );
    assert!(
        lhs.cols == rhs.cols,
        "Column count mismatch: left has {} columns, right has {} columns",
        lhs.cols,
        rhs.cols
    );
}

/// Generates element-wise arithmetic implementations for Matrix + Matrix.
///
/// Owned operands are consumed element by element, so only elements that are
/// read through a shared reference get cloned.
macro_rules! impl_elementwise_op_matrix_all {
    ($OpTrait:ident, $method:ident, $op:tt) => {
        // &Matrix OP &Matrix
        impl<'a, 'b, T> $OpTrait<&'b Matrix<T>> for &'a Matrix<T>
        where
            T: Clone + $OpTrait<Output = T>,
        {
            type Output = Matrix<T>;

            fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> {
                check_matrix_dims_for_ops(self, rhs);
                let data = self
                    .data
                    .iter()
                    .cloned()
                    .zip(rhs.data.iter().cloned())
                    .map(|(a, b)| a $op b)
                    .collect();
                Matrix { rows: self.rows, cols: self.cols, data }
            }
        }

        // Matrix OP &Matrix (consumes self)
        impl<'b, T> $OpTrait<&'b Matrix<T>> for Matrix<T>
        where
            T: Clone + $OpTrait<Output = T>,
        {
            type Output = Matrix<T>;

            fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> {
                check_matrix_dims_for_ops(&self, rhs);
                let Matrix { rows, cols, data: lhs_data } = self;
                let data = lhs_data
                    .into_iter()
                    .zip(rhs.data.iter().cloned())
                    .map(|(a, b)| a $op b)
                    .collect();
                Matrix { rows, cols, data }
            }
        }

        // &Matrix OP Matrix (consumes rhs)
        impl<'a, T> $OpTrait<Matrix<T>> for &'a Matrix<T>
        where
            T: Clone + $OpTrait<Output = T>,
        {
            type Output = Matrix<T>;

            fn $method(self, rhs: Matrix<T>) -> Matrix<T> {
                check_matrix_dims_for_ops(self, &rhs);
                let Matrix { rows, cols, data: rhs_data } = rhs;
                let data = self
                    .data
                    .iter()
                    .cloned()
                    .zip(rhs_data)
                    .map(|(a, b)| a $op b)
                    .collect();
                Matrix { rows, cols, data }
            }
        }

        // Matrix OP Matrix (consumes both)
        impl<T> $OpTrait<Matrix<T>> for Matrix<T>
        where
            T: Clone + $OpTrait<Output = T>,
        {
            type Output = Matrix<T>;

            fn $method(self, rhs: Matrix<T>) -> Matrix<T> {
                check_matrix_dims_for_ops(&self, &rhs);
                let Matrix { rows, cols, data: lhs_data } = self;
                let data = lhs_data
                    .into_iter()
                    .zip(rhs.data)
                    .map(|(a, b)| a $op b)
                    .collect();
                Matrix { rows, cols, data }
            }
        }
    };
}

/// Generates element-wise arithmetic implementations for Matrix + Scalar.
macro_rules! impl_elementwise_op_scalar_all {
    ($OpTrait:ident, $method:ident, $op:tt) => {
        // &Matrix OP Scalar
        impl<'a, T> $OpTrait<T> for &'a Matrix<T>
        where
            T: Clone + $OpTrait<Output = T>,
        {
            type Output = Matrix<T>;

            fn $method(self, rhs: T) -> Matrix<T> {
                let data = self
                    .data
                    .iter()
                    .cloned()
                    .map(|a| a $op rhs.clone())
                    .collect();
                Matrix { rows: self.rows, cols: self.cols, data }
            }
        }

        // Matrix OP Scalar (consumes self)
        impl<T> $OpTrait<T> for Matrix<T>
        where
            T: Clone + $OpTrait<Output = T>,
        {
            type Output = Matrix<T>;

            fn $method(self, rhs: T) -> Matrix<T> {
                let Matrix { rows, cols, data: lhs_data } = self;
                let data = lhs_data
                    .into_iter()
                    .map(|a| a $op rhs.clone())
                    .collect();
                Matrix { rows, cols, data }
            }
        }

        // NOTE: Scalar + Matrix (e.g., 1.0 + matrix) is NOT implemented here.
        // It would require `impl Add<Matrix<T>> for T`, which is discouraged
        // for primitive types unless inside the crate defining T.
    };
}

// Instantiate ALL combinations for arithmetic ops using the macros above.
impl_elementwise_op_matrix_all!(Add, add, +);
impl_elementwise_op_matrix_all!(Sub, sub, -);
impl_elementwise_op_matrix_all!(Mul, mul, *); // Element-wise multiplication
impl_elementwise_op_matrix_all!(Div, div, /); // Element-wise division

impl_elementwise_op_scalar_all!(Add, add, +);
impl_elementwise_op_scalar_all!(Sub, sub, -);
impl_elementwise_op_scalar_all!(Mul, mul, *);
impl_elementwise_op_scalar_all!(Div, div, /);

// --- Element-wise Bitwise Operations (BoolMatrix) ---

/// Generates element-wise bitwise implementations for `Matrix<bool>` in all
/// four owned/borrowed operand combinations.
macro_rules! impl_bitwise_op_all {
    ($OpTrait:ident, $method:ident, $op:tt) => {
        // &Matrix OP &Matrix
        impl<'a, 'b> $OpTrait<&'b Matrix<bool>> for &'a Matrix<bool> {
            type Output = Matrix<bool>;

            fn $method(self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
                check_matrix_dims_for_ops(self, rhs);
                let data = self
                    .data
                    .iter()
                    .zip(rhs.data.iter())
                    .map(|(&a, &b)| a $op b)
                    .collect();
                Matrix { rows: self.rows, cols: self.cols, data }
            }
        }

        // Matrix OP &Matrix (reuses self's buffer)
        impl<'b> $OpTrait<&'b Matrix<bool>> for Matrix<bool> {
            type Output = Matrix<bool>;

            fn $method(mut self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
                check_matrix_dims_for_ops(&self, rhs);
                for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) {
                    *a = *a $op *b;
                }
                self
            }
        }

        // &Matrix OP Matrix (reuses rhs's buffer)
        impl<'a> $OpTrait<Matrix<bool>> for &'a Matrix<bool> {
            type Output = Matrix<bool>;

            fn $method(self, mut rhs: Matrix<bool>) -> Matrix<bool> {
                check_matrix_dims_for_ops(self, &rhs);
                for (a, b) in self.data.iter().zip(rhs.data.iter_mut()) {
                    *b = *a $op *b;
                }
                rhs
            }
        }

        // Matrix OP Matrix (reuses self's buffer)
        impl $OpTrait<Matrix<bool>> for Matrix<bool> {
            type Output = Matrix<bool>;

            fn $method(mut self, rhs: Matrix<bool>) -> Matrix<bool> {
                check_matrix_dims_for_ops(&self, &rhs);
                for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) {
                    *a = *a $op *b;
                }
                self
            }
        }
    };
}

// Instantiate ALL combinations for bitwise ops.
impl_bitwise_op_all!(BitAnd, bitand, &);
impl_bitwise_op_all!(BitOr, bitor, |);
impl_bitwise_op_all!(BitXor, bitxor, ^);

// --- Logical Not ---

// `!Matrix` (consumes matrix)
impl Not for Matrix<bool> {
    type Output = Matrix<bool>;

    /// Inverts every element in place and returns the same matrix.
    fn not(mut self) -> Matrix<bool> {
        for val in self.data.iter_mut() {
            *val = !*val;
        }
        self
    }
}
Not for Matrix { type Output = Matrix; fn not(mut self) -> Matrix { // Take by value, make mutable for val in self.data.iter_mut() { *val = !*val; // Invert in place } self // Return the modified matrix } } // `!&Matrix` (borrows matrix, returns new matrix) impl Not for &Matrix { type Output = Matrix; fn not(self) -> Matrix { // Take by reference let data = self.data.iter().map(|&v| !v).collect(); // Create new data vec Matrix { rows: self.rows, cols: self.cols, data, } } } // --- Type Aliases --- pub type FloatMatrix = Matrix; pub type BoolMatrix = Matrix; pub type IntMatrix = Matrix; pub type StringMatrix = Matrix; // --- Unit Tests --- #[cfg(test)] mod tests { use crate::matrix::BoolOps; use super::*; // Import items from outer scope // Helper to create a 2x2 f64 matrix easily fn make_f64_matrix(a: f64, b: f64, c: f64, d: f64) -> FloatMatrix { Matrix::from_cols(vec![vec![a, c], vec![b, d]]) } // Helper to create a 2x2 bool matrix easily fn make_bool_matrix(a: bool, b: bool, c: bool, d: bool) -> BoolMatrix { Matrix::from_cols(vec![vec![a, c], vec![b, d]]) } // --- Arithmetic Tests --- #[test] fn test_add_f64() { let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); let m2 = make_f64_matrix(5.0, 6.0, 7.0, 8.0); let expected = make_f64_matrix(6.0, 8.0, 10.0, 12.0); assert_eq!(m1.clone() + m2.clone(), expected, "M + M"); assert_eq!(m1.clone() + &m2, expected, "M + &M"); assert_eq!(&m1 + m2.clone(), expected, "&M + M"); assert_eq!(&m1 + &m2, expected, "&M + &M"); } #[test] fn test_add_scalar_f64() { let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); let scalar = 10.0; let expected = make_f64_matrix(11.0, 12.0, 13.0, 14.0); assert_eq!(m1.clone() + scalar, expected, "M + S"); assert_eq!(&m1 + scalar, expected, "&M + S"); } #[test] fn test_sub_f64() { let m1 = make_f64_matrix(10.0, 20.0, 30.0, 40.0); let m2 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); let expected = make_f64_matrix(9.0, 18.0, 27.0, 36.0); assert_eq!(m1.clone() - m2.clone(), expected, "M - M"); assert_eq!(m1.clone() - &m2, 
expected, "M - &M"); assert_eq!(&m1 - m2.clone(), expected, "&M - M"); assert_eq!(&m1 - &m2, expected, "&M - &M"); } #[test] fn test_sub_scalar_f64() { let m1 = make_f64_matrix(11.0, 12.0, 13.0, 14.0); let scalar = 10.0; let expected = make_f64_matrix(1.0, 2.0, 3.0, 4.0); assert_eq!(m1.clone() - scalar, expected, "M - S"); assert_eq!(&m1 - scalar, expected, "&M - S"); } #[test] fn test_mul_f64() { // Element-wise let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); let m2 = make_f64_matrix(5.0, 6.0, 7.0, 8.0); let expected = make_f64_matrix(5.0, 12.0, 21.0, 32.0); assert_eq!(m1.clone() * m2.clone(), expected, "M * M"); assert_eq!(m1.clone() * &m2, expected, "M * &M"); assert_eq!(&m1 * m2.clone(), expected, "&M * M"); assert_eq!(&m1 * &m2, expected, "&M * &M"); } #[test] fn test_mul_scalar_f64() { let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); let scalar = 3.0; let expected = make_f64_matrix(3.0, 6.0, 9.0, 12.0); assert_eq!(m1.clone() * scalar, expected, "M * S"); assert_eq!(&m1 * scalar, expected, "&M * S"); } #[test] fn test_div_f64() { // Element-wise let m1 = make_f64_matrix(10.0, 20.0, 30.0, 40.0); let m2 = make_f64_matrix(2.0, 5.0, 6.0, 8.0); let expected = make_f64_matrix(5.0, 4.0, 5.0, 5.0); assert_eq!(m1.clone() / m2.clone(), expected, "M / M"); assert_eq!(m1.clone() / &m2, expected, "M / &M"); assert_eq!(&m1 / m2.clone(), expected, "&M / M"); assert_eq!(&m1 / &m2, expected, "&M / &M"); } #[test] fn test_div_scalar_f64() { let m1 = make_f64_matrix(10.0, 20.0, 30.0, 40.0); let scalar = 10.0; let expected = make_f64_matrix(1.0, 2.0, 3.0, 4.0); assert_eq!(m1.clone() / scalar, expected, "M / S"); assert_eq!(&m1 / scalar, expected, "&M / S"); } #[test] fn test_chained_ops_f64() { let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0); let result = (((m.clone() + 1.0) * 2.0) - 4.0) / 2.0; // Expected: // m+1 = [2, 3], [4, 5] // *2 = [4, 6], [8, 10] // -4 = [0, 2], [4, 6] // /2 = [0, 1], [2, 3] let expected = make_f64_matrix(0.0, 1.0, 2.0, 3.0); assert_eq!(result, expected); } // 
--- Boolean Logic Tests --- #[test] fn test_bitand_bool() { let m1 = make_bool_matrix(true, false, true, false); let m2 = make_bool_matrix(true, true, false, false); let expected = make_bool_matrix(true, false, false, false); assert_eq!(m1.clone() & m2.clone(), expected, "M & M"); assert_eq!(m1.clone() & &m2, expected, "M & &M"); assert_eq!(&m1 & m2.clone(), expected, "&M & M"); assert_eq!(&m1 & &m2, expected, "&M & &M"); } #[test] fn test_bitor_bool() { let m1 = make_bool_matrix(true, false, true, false); let m2 = make_bool_matrix(true, true, false, false); let expected = make_bool_matrix(true, true, true, false); assert_eq!(m1.clone() | m2.clone(), expected, "M | M"); assert_eq!(m1.clone() | &m2, expected, "M | &M"); assert_eq!(&m1 | m2.clone(), expected, "&M | M"); assert_eq!(&m1 | &m2, expected, "&M | &M"); } #[test] fn test_bitxor_bool() { let m1 = make_bool_matrix(true, false, true, false); let m2 = make_bool_matrix(true, true, false, false); let expected = make_bool_matrix(false, true, true, false); assert_eq!(m1.clone() ^ m2.clone(), expected, "M ^ M"); assert_eq!(m1.clone() ^ &m2, expected, "M ^ &M"); assert_eq!(&m1 ^ m2.clone(), expected, "&M ^ M"); assert_eq!(&m1 ^ &m2, expected, "&M ^ &M"); } #[test] fn test_not_bool() { let m = make_bool_matrix(true, false, true, false); let expected = make_bool_matrix(false, true, false, true); assert_eq!(!m.clone(), expected, "!M (consuming)"); assert_eq!(!&m, expected, "!&M (borrowing)"); // Check original is unchanged when using !&M let original = make_bool_matrix(true, false, true, false); let _negated_ref = !&original; assert_eq!(original, make_bool_matrix(true, false, true, false)); } // --- Comparison Tests --- #[test] fn test_comparison_eq_elem() { let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); let m2 = make_f64_matrix(1.0, 0.0, 3.0, 5.0); let s = 3.0; let expected_m = make_bool_matrix(true, false, true, false); let expected_s = make_bool_matrix(false, false, true, false); assert_eq!(m1.eq_elem(m2), expected_m, 
"eq_elem matrix"); assert_eq!(m1.eq_elem(s), expected_s, "eq_elem scalar"); } #[test] fn test_comparison_gt_elem() { let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0); let m2 = make_f64_matrix(0.0, 3.0, 3.0, 5.0); let s = 2.5; let expected_m = make_bool_matrix(true, false, false, false); let expected_s = make_bool_matrix(false, false, true, true); assert_eq!(m1.gt_elem(m2), expected_m, "gt_elem matrix"); assert_eq!(m1.gt_elem(s), expected_s, "gt_elem scalar"); } // Add more comparison tests (lt, le, ge, ne) if desired... // --- Basic Method Tests --- #[test] fn test_indexing() { let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0); assert_eq!(m[(0, 0)], 1.0); assert_eq!(m[(0, 1)], 2.0); assert_eq!(m[(1, 0)], 3.0); assert_eq!(m[(1, 1)], 4.0); assert_eq!(*m.get(1, 0), 3.0); // Test get() too } #[test] #[should_panic] fn test_index_out_of_bounds_row() { let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0); let _ = m[(2, 0)]; // Row 2 is out of bounds } #[test] #[should_panic] fn test_index_out_of_bounds_col() { let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0); let _ = m[(0, 2)]; // Col 2 is out of bounds } #[test] fn test_dimensions() { let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0); assert_eq!(m.rows(), 2); assert_eq!(m.cols(), 2); } #[test] fn test_from_vec() { let data = vec![1.0, 3.0, 2.0, 4.0]; // Column major: [col0_row0, col0_row1, col1_row0, col1_row1] let m = Matrix::from_vec(data, 2, 2); let expected = make_f64_matrix(1.0, 2.0, 3.0, 4.0); assert_eq!(m, expected); assert_eq!(m.to_vec(), vec![1.0, 3.0, 2.0, 4.0]); } // Helper function to create a basic Matrix for testing fn static_test_matrix() -> Matrix { // Column-major data: // 1 4 7 // 2 5 8 // 3 6 9 let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; Matrix::from_vec(data, 3, 3) } // Another helper for a different size fn static_test_matrix_2x4() -> Matrix { // Column-major data: // 1 3 5 7 // 2 4 6 8 let data = vec![1, 2, 3, 4, 5, 6, 7, 8]; Matrix::from_vec(data, 2, 4) } #[test] fn test_from_vec_basic() { let data = vec![1, 2, 3, 4, 5, 6]; // 
2 rows, 3 cols (column-major) let matrix = Matrix::from_vec(data, 2, 3); assert_eq!(matrix.rows(), 2); assert_eq!(matrix.cols(), 3); assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6]); // Check some elements assert_eq!(matrix[(0, 0)], 1); // First row, first col assert_eq!(matrix[(1, 0)], 2); // Second row, first col assert_eq!(matrix[(0, 1)], 3); // First row, second col assert_eq!(matrix[(1, 2)], 6); // Second row, third col } #[test] #[should_panic(expected = "data length mismatch")] fn test_from_vec_wrong_length() { let data = vec![1, 2, 3, 4, 5]; // Should be 6 for 2x3 Matrix::from_vec(data, 2, 3); } #[test] #[should_panic(expected = "need at least one row")] fn test_from_vec_zero_rows() { let data = vec![1, 2, 3]; Matrix::from_vec(data, 0, 3); } #[test] #[should_panic(expected = "need at least one column")] fn test_from_vec_zero_cols() { let data = vec![1, 2, 3]; Matrix::from_vec(data, 3, 0); } #[test] fn test_from_cols_basic() { // Representing: // 1 4 7 // 2 5 8 // 3 6 9 let cols_data = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]; let matrix = Matrix::from_cols(cols_data); assert_eq!(matrix.rows(), 3); assert_eq!(matrix.cols(), 3); // Internal data should be column-major assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]); // Check some elements assert_eq!(matrix[(0, 0)], 1); assert_eq!(matrix[(2, 0)], 3); assert_eq!(matrix[(1, 1)], 5); assert_eq!(matrix[(0, 2)], 7); } #[test] fn test_from_cols_1x1() { let cols_data = vec![vec![42]]; let matrix = Matrix::from_cols(cols_data); assert_eq!(matrix.rows(), 1); assert_eq!(matrix.cols(), 1); assert_eq!(matrix.data(), &[42]); assert_eq!(matrix[(0, 0)], 42); } #[test] #[should_panic(expected = "need at least one column")] fn test_from_cols_empty_cols() { let empty_cols: Vec> = vec![]; Matrix::from_cols(empty_cols); } #[test] #[should_panic(expected = "need at least one row")] fn test_from_cols_empty_rows() { let empty_row: Vec> = vec![vec![], vec![]]; Matrix::from_cols(empty_row); } #[test] 
#[should_panic(expected = "col 1 has len 2, expected 3")] fn test_from_cols_mismatched_lengths() { let cols_data = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8]]; Matrix::from_cols(cols_data); } #[test] fn test_getters() { let matrix = static_test_matrix(); assert_eq!(matrix.rows(), 3); assert_eq!(matrix.cols(), 3); assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]); } #[test] fn test_index_and_get() { let matrix = static_test_matrix(); assert_eq!(matrix[(0, 0)], 1); assert_eq!(matrix[(1, 1)], 5); assert_eq!(matrix[(2, 2)], 9); assert_eq!(*matrix.get(0, 0), 1); assert_eq!(*matrix.get(1, 1), 5); assert_eq!(*matrix.get(2, 2), 9); } #[test] #[should_panic(expected = "index out of bounds")] fn test_index_out_of_bounds_row_alt() { let matrix = static_test_matrix(); // 3x3 let _ = matrix[(3, 0)]; } #[test] #[should_panic(expected = "index out of bounds")] fn test_index_out_of_bounds_col_alt() { let matrix = static_test_matrix(); // 3x3 let _ = matrix[(0, 3)]; } #[test] fn test_index_mut_and_get_mut() { let mut matrix = static_test_matrix(); // 3x3 matrix[(0, 0)] = 10; matrix[(1, 1)] = 20; matrix[(2, 2)] = 30; assert_eq!(matrix[(0, 0)], 10); assert_eq!(matrix[(1, 1)], 20); assert_eq!(matrix[(2, 2)], 30); *matrix.get_mut(0, 1) = 15; *matrix.get_mut(2, 1) = 25; assert_eq!(matrix[(0, 1)], 15); assert_eq!(matrix[(2, 1)], 25); // Check underlying data consistency (column-major) // Should be: // 10 15 7 // 2 20 8 // 3 25 30 assert_eq!(matrix.data(), &[10, 2, 3, 15, 20, 25, 7, 8, 30]); } #[test] #[should_panic(expected = "index out of bounds")] fn test_index_mut_out_of_bounds_row() { let mut matrix = static_test_matrix(); // 3x3 matrix[(3, 0)] = 99; } #[test] #[should_panic(expected = "index out of bounds")] fn test_index_mut_out_of_bounds_col() { let mut matrix = static_test_matrix(); // 3x3 matrix[(0, 3)] = 99; } #[test] fn test_column() { let matrix = static_test_matrix_2x4(); // 2x4 // 1 3 5 7 // 2 4 6 8 assert_eq!(matrix.column(0), &[1, 2]); 
assert_eq!(matrix.column(1), &[3, 4]);
        assert_eq!(matrix.column(2), &[5, 6]);
        assert_eq!(matrix.column(3), &[7, 8]);
    }

    /// `column` must panic when the index equals the column count.
    #[test]
    #[should_panic(expected = "column index 4 out of bounds for 4 columns")]
    fn test_column_out_of_bounds() {
        let matrix = static_test_matrix_2x4(); // 2x4
        matrix.column(4);
    }

    /// Writes through `column_mut` must be visible via indexing and `data()`.
    #[test]
    fn test_column_mut() {
        let mut matrix = static_test_matrix_2x4(); // 2x4
        // 1 3 5 7
        // 2 4 6 8

        let col1_mut = matrix.column_mut(1);
        col1_mut[0] = 30;
        col1_mut[1] = 40;

        let col3_mut = matrix.column_mut(3);
        col3_mut[0] = 70;

        // Check changes via indexing
        assert_eq!(matrix[(0, 1)], 30);
        assert_eq!(matrix[(1, 1)], 40);
        assert_eq!(matrix[(0, 3)], 70);
        assert_eq!(matrix[(1, 3)], 8); // Unchanged

        // Check underlying data (column-major)
        // Should be:
        // 1 30 5 70
        // 2 40 6  8
        assert_eq!(matrix.data(), &[1, 2, 30, 40, 5, 6, 70, 8]);
    }

    /// `column_mut` must panic when the index equals the column count.
    #[test]
    #[should_panic(expected = "column index 4 out of bounds for 4 columns")]
    fn test_column_mut_out_of_bounds() {
        let mut matrix = static_test_matrix_2x4(); // 2x4
        matrix.column_mut(4);
    }

    /// `iter_columns` yields every column, in order, as a slice.
    #[test]
    fn test_iter_columns() {
        let matrix = static_test_matrix_2x4(); // 2x4
        // 1 3 5 7
        // 2 4 6 8

        let cols: Vec<&[i32]> = matrix.iter_columns().collect();
        assert_eq!(cols.len(), 4);
        assert_eq!(cols[0], &[1, 2]);
        assert_eq!(cols[1], &[3, 4]);
        assert_eq!(cols[2], &[5, 6]);
        assert_eq!(cols[3], &[7, 8]);
    }

    /// `iter_rows` yields every row, in order; each row is collected into a `Vec`.
    #[test]
    fn test_iter_rows() {
        let matrix = static_test_matrix_2x4(); // 2x4
        // 1 3 5 7
        // 2 4 6 8

        let rows: Vec<Vec<i32>> = matrix
            .iter_rows()
            .map(|row| row.iter().cloned().collect())
            .collect();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0], vec![1, 3, 5, 7]);
        assert_eq!(rows[1], vec![2, 4, 6, 8]);
    }

    // test data_mut
    /// Writes through `data_mut` land in column-major positions.
    #[test]
    fn test_data_mut() {
        let mut matrix = static_test_matrix(); // 3x3
        // 1 4 7
        // 2 5 8
        // 3 6 9

        let data_mut = matrix.data_mut();
        data_mut[0] = 10;
        data_mut[1] = 20;

        assert_eq!(matrix[(0, 0)], 10);
        assert_eq!(matrix[(1, 0)], 20);
    }

    /// Row views returned by `iter_rows` support both `get` and `iter`.
    #[test]
    fn test_matrix_row_get_and_iter() {
        let matrix = static_test_matrix_2x4(); // 2x4
        // 1 3 5 7
        // 2 4 6 8

        let row0 = matrix.iter_rows().next().unwrap();
        assert_eq!(*row0.get(0), 1);
        assert_eq!(*row0.get(1), 3);
        assert_eq!(*row0.get(3), 7);
        let row0_vec: Vec<i32> = row0.iter().cloned().collect();
        assert_eq!(row0_vec, vec![1, 3, 5, 7]);

        let row1 = matrix.iter_rows().nth(1).unwrap();
        assert_eq!(*row1.get(0), 2);
        assert_eq!(*row1.get(2), 6);
        let row1_vec: Vec<i32> = row1.iter().cloned().collect();
        assert_eq!(row1_vec, vec![2, 4, 6, 8]);
    }

    /// Swapping two distinct columns exchanges their data; a self-swap is a no-op.
    #[test]
    fn test_swap_columns() {
        let mut matrix = static_test_matrix(); // 3x3
        // 1 4 7
        // 2 5 8
        // 3 6 9

        matrix.swap_columns(0, 2); // Swap first and last

        // Should be:
        // 7 4 1
        // 8 5 2
        // 9 6 3
        assert_eq!(matrix.rows(), 3);
        assert_eq!(matrix.cols(), 3);
        assert_eq!(matrix[(0, 0)], 7);
        assert_eq!(matrix[(1, 0)], 8);
        assert_eq!(matrix[(2, 0)], 9);
        assert_eq!(matrix[(0, 1)], 4); // Middle col unchanged
        assert_eq!(matrix[(1, 1)], 5);
        assert_eq!(matrix[(2, 1)], 6);
        assert_eq!(matrix[(0, 2)], 1);
        assert_eq!(matrix[(1, 2)], 2);
        assert_eq!(matrix[(2, 2)], 3);

        // Swap the same column (should do nothing)
        let original_data = matrix.data().to_vec();
        matrix.swap_columns(1, 1);
        assert_eq!(matrix.data(), &original_data); // Data should be identical

        // Check underlying data (column-major) after swap(0, 2)
        assert_eq!(matrix.data(), &[7, 8, 9, 4, 5, 6, 1, 2, 3]);
    }

    /// `swap_columns` must panic when the second index is out of range.
    #[test]
    #[should_panic(expected = "column index c2=3 out of bounds for 3 columns")]
    fn test_swap_columns_out_of_bounds() {
        let mut matrix = static_test_matrix(); // 3x3
        matrix.swap_columns(0, 3);
    }

    /// Deleting columns one by one shrinks `cols` and the data, down to zero columns.
    #[test]
    fn test_delete_column() {
        let mut matrix = static_test_matrix_2x4(); // 2x4
        // 1 3 5 7
        // 2 4 6 8

        matrix.delete_column(1); // Delete the second column (index 1)

        // Should be:
        // 1 5 7
        // 2 6 8
        assert_eq!(matrix.rows(), 2);
        assert_eq!(matrix.cols(), 3);
        assert_eq!(matrix[(0, 0)], 1);
        assert_eq!(matrix[(1, 0)], 2);
        assert_eq!(matrix[(0, 1)], 5);
        assert_eq!(matrix[(1, 1)], 6);
        assert_eq!(matrix[(0, 2)], 7);
        assert_eq!(matrix[(1, 2)], 8);

        // Check underlying data (column-major)
        assert_eq!(matrix.data(), &[1, 2, 5, 6, 7, 8]);

        // Delete the first column
        matrix.delete_column(0);
        // Should be:
        // 5 7
        // 6 8
        assert_eq!(matrix.rows(), 2);
        assert_eq!(matrix.cols(), 2);
        assert_eq!(matrix.data(), &[5, 6, 7, 8]);

        // Delete the last column
        matrix.delete_column(1);
        // Should be:
        // 5
        // 6
        assert_eq!(matrix.rows(), 2);
        assert_eq!(matrix.cols(), 1);
        assert_eq!(matrix.data(), &[5, 6]);

        // Delete the only column
        matrix.delete_column(0);
        // Should be empty
        assert_eq!(matrix.rows(), 2); // Rows stay the same
        assert_eq!(matrix.cols(), 0); // Cols becomes 0
        assert_eq!(matrix.data(), &[]);
    }

    /// `delete_column` must panic when the index equals the column count.
    #[test]
    #[should_panic(expected = "column index 4 out of bounds for 4 columns")]
    fn test_delete_column_out_of_bounds() {
        let mut matrix = static_test_matrix_2x4(); // 2x4
        matrix.delete_column(4);
    }

    /// Deleting rows one by one shrinks `rows` and the data, down to zero rows.
    #[test]
    fn test_delete_row() {
        let mut matrix = static_test_matrix(); // 3x3
        // 1 4 7
        // 2 5 8
        // 3 6 9

        matrix.delete_row(1); // Delete the second row (index 1)

        // Should be:
        // 1 4 7
        // 3 6 9
        assert_eq!(matrix.rows(), 2);
        assert_eq!(matrix.cols(), 3);
        assert_eq!(matrix[(0, 0)], 1);
        assert_eq!(matrix[(1, 0)], 3);
        assert_eq!(matrix[(0, 1)], 4);
        assert_eq!(matrix[(1, 1)], 6);
        assert_eq!(matrix[(0, 2)], 7);
        assert_eq!(matrix[(1, 2)], 9);

        // Check underlying data (column-major)
        // Original: [1, 2, 3, 4, 5, 6, 7, 8, 9]
        // Delete row 1: [1, 3, 4, 6, 7, 9]
        assert_eq!(matrix.data(), &[1, 3, 4, 6, 7, 9]);

        // Delete the first row
        matrix.delete_row(0);
        // Should be:
        // 3 6 9
        assert_eq!(matrix.rows(), 1);
        assert_eq!(matrix.cols(), 3);
        assert_eq!(matrix.data(), &[3, 6, 9]);

        // Delete the last (and only) row
        matrix.delete_row(0);
        // Should be empty
        assert_eq!(matrix.rows(), 0); // Rows becomes 0
        assert_eq!(matrix.cols(), 3); // Cols stay the same
        assert_eq!(matrix.data(), &[]);
    }

    /// `delete_row` must panic when the index equals the row count.
    #[test]
    #[should_panic(expected = "row index 3 out of bounds for 3 rows")]
    fn test_delete_row_out_of_bounds() {
        let mut matrix = static_test_matrix(); // 3x3
        matrix.delete_row(3);
    }

    /// Inserting a column in the middle, at the front and at the end shifts later columns.
    #[test]
    fn test_add_column() {
        let mut matrix = static_test_matrix_2x4(); // 2x4
        // 1 3 5 7
        // 2 4 6 8

        let new_col = vec![9, 10];
        matrix.add_column(2, new_col); // Add at index 2

        // Should be:
        // 1 3  9 5 7
        // 2 4 10 6 8
        assert_eq!(matrix.rows(), 2);
        assert_eq!(matrix.cols(), 5);
        assert_eq!(matrix[(0, 0)], 1);
        assert_eq!(matrix[(1, 0)], 2);
        assert_eq!(matrix[(0, 1)], 3);
        assert_eq!(matrix[(1, 1)], 4);
        assert_eq!(matrix[(0, 2)], 9);
        assert_eq!(matrix[(1, 2)], 10);
        assert_eq!(matrix[(0, 3)], 5); // Shifted
        assert_eq!(matrix[(1, 3)], 6);
        assert_eq!(matrix[(0, 4)], 7); // Shifted
        assert_eq!(matrix[(1, 4)], 8);

        // Check underlying data (column-major)
        // Original: [1, 2, 3, 4, 5, 6, 7, 8]
        // Add [9, 10] at index 2: [1, 2, 3, 4, 9, 10, 5, 6, 7, 8]
        assert_eq!(matrix.data(), &[1, 2, 3, 4, 9, 10, 5, 6, 7, 8]);

        // Add a column at the beginning
        let new_col_start = vec![11, 12];
        matrix.add_column(0, new_col_start);
        // Should be:
        // 11 1 3  9 5 7
        // 12 2 4 10 6 8
        assert_eq!(matrix.rows(), 2);
        assert_eq!(matrix.cols(), 6);
        assert_eq!(matrix[(0, 0)], 11);
        assert_eq!(matrix[(1, 0)], 12);
        assert_eq!(matrix.data(), &[11, 12, 1, 2, 3, 4, 9, 10, 5, 6, 7, 8]);

        // Add a column at the end
        let new_col_end = vec![13, 14];
        matrix.add_column(6, new_col_end);
        // Should be:
        // 11 1 3  9 5 7 13
        // 12 2 4 10 6 8 14
        assert_eq!(matrix.rows(), 2);
        assert_eq!(matrix.cols(), 7);
        assert_eq!(matrix[(0, 6)], 13);
        assert_eq!(matrix[(1, 6)], 14);
        assert_eq!(
            matrix.data(),
            &[11, 12, 1, 2, 3, 4, 9, 10, 5, 6, 7, 8, 13, 14]
        );
    }

    /// `add_column` must panic when the insertion index is past the column count.
    #[test]
    #[should_panic(expected = "add_column index 5 out of bounds for 4 columns")]
    fn test_add_column_out_of_bounds() {
        let mut matrix = static_test_matrix_2x4(); // 2x4
        let new_col = vec![9, 10];
        matrix.add_column(5, new_col); // Index 5 is out of bounds for 4 columns
    }

    /// `add_column` must panic when the new column's length differs from the row count.
    #[test]
    #[should_panic(expected = "column length mismatch")]
    fn test_add_column_length_mismatch() {
        let mut matrix = static_test_matrix_2x4(); // 2x4 (2 rows)
        let new_col = vec![9, 10, 11]; // Wrong length
        matrix.add_column(0, new_col);
    }

    /// Inserting a row in the middle, at the front and at the end shifts later rows.
    #[test]
    fn test_add_row() {
        let mut matrix = static_test_matrix_2x4(); // 2x4
        // 1 3 5 7
        // 2 4 6 8

        let new_row = vec![9, 10, 11, 12];
        matrix.add_row(1, new_row); // Add at index 1

        // Should be:
        // 1  3  5  7
        // 9 10 11 12
        // 2  4  6  8
        assert_eq!(matrix.rows(), 3);
        assert_eq!(matrix.cols(), 4);

        assert_eq!(matrix[(0, 0)], 1);
        assert_eq!(matrix[(0, 1)], 3);
        assert_eq!(matrix[(0, 2)], 5);
        assert_eq!(matrix[(0, 3)], 7);
        assert_eq!(matrix[(1, 0)], 9);
        assert_eq!(matrix[(1, 1)], 10);
        assert_eq!(matrix[(1, 2)], 11);
        assert_eq!(matrix[(1, 3)], 12);
        assert_eq!(matrix[(2, 0)], 2);
        assert_eq!(matrix[(2, 1)], 4);
        assert_eq!(matrix[(2, 2)], 6);
        assert_eq!(matrix[(2, 3)], 8);

        // Check underlying data (column-major)
        // Original: [1, 2, 3, 4, 5, 6, 7, 8] (rows 0, 1)
        // Add [9, 10, 11, 12] at index 1 (new row will be index 1, original row 1 becomes index 2)
        // Col 0: [1, 9, 2]
        // Col 1: [3, 10, 4]
        // Col 2: [5, 11, 6]
        // Col 3: [7, 12, 8]
        // Data: [1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]
        assert_eq!(matrix.data(), &[1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]);

        // Add a row at the beginning
        let new_row_start = vec![13, 14, 15, 16];
        matrix.add_row(0, new_row_start);
        // Should be:
        // 13 14 15 16
        //  1  3  5  7
        //  9 10 11 12
        //  2  4  6  8
        assert_eq!(matrix.rows(), 4);
        assert_eq!(matrix.cols(), 4);
        assert_eq!(matrix[(0, 0)], 13);
        assert_eq!(matrix[(0, 1)], 14);
        assert_eq!(matrix[(0, 2)], 15);
        assert_eq!(matrix[(0, 3)], 16);
        assert_eq!(matrix[(1, 0)], 1);
        assert_eq!(matrix[(2, 1)], 10);
        assert_eq!(matrix[(3, 3)], 8);

        // Add a row at the end
        let new_row_end = vec![17, 18, 19, 20];
        matrix.add_row(4, new_row_end);
        // Should be:
        // 13 14 15 16
        //  1  3  5  7
        //  9 10 11 12
        //  2  4  6  8
        // 17 18 19 20
        assert_eq!(matrix.rows(), 5);
        assert_eq!(matrix.cols(), 4);
        assert_eq!(matrix[(4, 0)], 17);
        assert_eq!(matrix[(4, 3)], 20);
    }

    /// `add_row` must panic when the insertion index is past the row count.
    #[test]
    #[should_panic(expected = "add_row index 3 out of bounds for 2 rows")]
    fn test_add_row_out_of_bounds() {
        let mut matrix = static_test_matrix_2x4(); // 2x4
        let new_row = vec![9, 10, 11, 12];
        matrix.add_row(3, new_row); // Index 3 is out of bounds for 2 rows
    }

    /// `add_row` must panic when the new row's length differs from the column count.
    #[test]
    #[should_panic(expected = "row length mismatch")]
    fn test_add_row_length_mismatch() {
        let mut matrix = static_test_matrix_2x4(); // 2x4 (4 cols)
        let new_row = vec![9, 10, 11]; // Wrong length
        matrix.add_row(0, new_row);
    }

    /// Element-wise `+` on two 3x3 integer matrices.
    #[test]
    fn test_elementwise_add() {
        let matrix1 = static_test_matrix(); // 3x3
        let matrix2 = Matrix::from_vec(vec![9, 8, 7, 6, 5, 4, 3, 2, 1], 3, 3); // 3x3

        let result = &matrix1 + &matrix2;

        assert_eq!(result.rows(), 3);
        assert_eq!(result.cols(), 3);

        // Expected:
        // 1+9 4+6 7+3 => 10 10 10
        // 2+8 5+5 8+2 => 10 10 10
        // 3+7 6+4 9+1 => 10 10 10
        // Column-major data: [10, 10, 10, 10, 10, 10, 10, 10, 10]
        assert_eq!(result.data(), &[10, 10, 10, 10, 10, 10, 10, 10, 10]);
        assert_eq!(result[(0, 0)], 10);
        assert_eq!(result[(1, 1)], 10);
        assert_eq!(result[(2, 2)], 10);
    }

    /// Element-wise `-` on two 3x3 integer matrices.
    #[test]
    fn test_elementwise_sub() {
        let matrix1 = static_test_matrix(); // 3x3
        let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], 3, 3); // 3x3

        let result = &matrix1 - &matrix2;

        assert_eq!(result.rows(), 3);
        assert_eq!(result.cols(), 3);

        // Expected:
        // 1-1 4-2 7-3 => 0 2 4
        // 2-1 5-2 8-3 => 1 3 5
        // 3-1 6-2 9-3 => 2 4 6
        // Column-major data: [0, 1, 2, 2, 3, 4, 4, 5, 6]
        assert_eq!(result.data(), &[0, 1, 2, 2, 3, 4, 4, 5, 6]);
        assert_eq!(result[(0, 0)], 0);
        assert_eq!(result[(1, 1)], 3);
        assert_eq!(result[(2, 2)], 6);
    }

    /// Element-wise `*` on two 3x3 integer matrices (not a matrix product).
    #[test]
    fn test_elementwise_mul() {
        let matrix1 = static_test_matrix(); // 3x3
        let matrix2 = Matrix::from_vec(vec![1, 2, 3, 1, 2, 3, 1, 2, 3], 3, 3); // 3x3

        let result = &matrix1 * &matrix2;

        assert_eq!(result.rows(), 3);
        assert_eq!(result.cols(), 3);

        // Expected:
        // 1*1 4*1 7*1 => 1  4  7
        // 2*2 5*2 8*2 => 4 10 16
        // 3*3 6*3 9*3 => 9 18 27
        // Column-major data: [1, 4, 9, 4, 10, 18, 7, 16, 27]
        assert_eq!(result.data(), &[1, 4, 9, 4, 10, 18, 7, 16, 27]);
        assert_eq!(result[(0, 0)], 1);
        assert_eq!(result[(1, 1)], 10);
        assert_eq!(result[(2, 2)], 27);
    }

    /// Element-wise `/` on two 3x3 integer matrices (integer division).
    #[test]
    fn test_elementwise_div() {
        let matrix1 = static_test_matrix(); // 3x3
        let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 7, 8, 9], 3, 3); // 3x3

        let result = &matrix1 / &matrix2; // Integer division

        assert_eq!(result.rows(), 3);
        assert_eq!(result.cols(), 3);

        // Expected:
        // 1/1 4/2 7/7 => 1 2 1
        // 2/1 5/2 8/8 => 2 2 1 (integer division)
        // 3/1 6/2 9/9 => 3 3 1
        // Column-major data: [1, 2, 3, 2, 2, 3, 1, 1, 1]
        assert_eq!(result.data(), &[1, 2, 3, 2, 2, 3, 1, 1, 1]);
        assert_eq!(result[(0, 0)], 1);
        assert_eq!(result[(1, 1)], 2);
        assert_eq!(result[(2, 2)], 1);
    }

    /// `+` on operands with different row counts must panic with the row-mismatch message.
    #[test]
    #[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 rows")]
    fn test_elementwise_op_row_mismatch() {
        let matrix1 = static_test_matrix(); // 3x3
        let matrix2 = static_test_matrix_2x4(); // 2x4
        let _ = &matrix1 + &matrix2; // Should panic
    }

    /// `*` on a 3x3 and a 2x4 operand must panic.
    ///
    /// NOTE(review): despite the name, these operands differ in row count as well,
    /// and the expected message is the row-mismatch one. A dedicated
    /// column-mismatch case would need operands with equal row counts.
    #[test]
    #[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 rows")]
    fn test_elementwise_op_col_mismatch() {
        let matrix1 = static_test_matrix(); // 3x3
        let matrix2 = static_test_matrix_2x4(); // 2x4
        let _ = &matrix1 * &matrix2; // Should panic
    }

    /// Element-wise `&` on two 2x3 boolean matrices.
    #[test]
    fn test_bitwise_and() {
        let data1 = vec![true, false, true, false, true, false]; // 2x3
        let data2 = vec![true, true, false, false, true, true]; // 2x3
        let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
        let matrix2 = BoolMatrix::from_vec(data2, 2, 3);

        // Expected column-major results:
        // T & T = T
        // F & T = F
        // T & F = F
        // F & F = F
        // T & T = T
        // F & T = F
        // Data: [T, F, F, F, T, F]
        let expected_data = vec![true, false, false, false, true, false];
        let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);

        let result = &matrix1 & &matrix2;
        assert_eq!(result, expected_matrix);
    }

    /// Element-wise `|` on two 2x3 boolean matrices.
    #[test]
    fn test_bitwise_or() {
        let data1 = vec![true, false, true, false, true, false]; // 2x3
        let data2 = vec![true, true, false, false, true, true]; // 2x3
        let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
        let matrix2 = BoolMatrix::from_vec(data2, 2, 3);

        // Expected column-major results:
        // T | T = T
        // F | T = T
        // T | F = T
        // F | F = F
        // T | T = T
        // F | T = T
        // Data: [T, T, T, F, T, T]
        let expected_data = vec![true, true, true, false, true, true];
        let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);

        let result = &matrix1 | &matrix2;
        assert_eq!(result, expected_matrix);
    }

    /// Element-wise `^` on two 2x3 boolean matrices.
    #[test]
    fn test_bitwise_xor() {
        let data1 = vec![true, false, true, false, true, false]; // 2x3
        let data2 = vec![true, true, false, false, true, true]; // 2x3
        let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
        let matrix2 = BoolMatrix::from_vec(data2, 2, 3);

        // Expected column-major results:
        // T ^ T = F
        // F ^ T = T
        // T ^ F = T
        // F ^ F = F
        // T ^ T = F
        // F ^ T = T
        // Data: [F, T, T, F, F, T]
        let expected_data = vec![false, true, true, false, false, true];
        let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);

        let result = &matrix1 ^ &matrix2;
        assert_eq!(result, expected_matrix);
    }

    /// Element-wise `!` on a 2x3 boolean matrix.
    #[test]
    fn test_bitwise_not() {
        let data = vec![true, false, true, false, true, false]; // 2x3
        let matrix = BoolMatrix::from_vec(data, 2, 3);

        // Expected column-major results:
        // !T = F
        // !F = T
        // !T = F
        // !F = T
        // !T = F
        // !F = T
        // Data: [F, T, F, T, F, T]
        let expected_data = vec![false, true, false, true, false, true];
        let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);

        let result = !matrix; // Not consumes the matrix
        assert_eq!(result, expected_matrix);
    }

    /// `&` on a 2x2 and a 2x3 operand must panic with the column-mismatch message.
    ///
    /// NOTE(review): despite the name, both operands have 2 rows; only the column
    /// counts differ, so this exercises the column check.
    #[test]
    #[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")]
    fn test_bitwise_op_row_mismatch() {
        let data1 = vec![true, false, true, false]; // 2x2
        let data2 = vec![true, true, false, false, true, true]; // 2x3
        let matrix1 = BoolMatrix::from_vec(data1, 2, 2);
        let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
        let _ = &matrix1 & &matrix2; // Should panic
    }

    /// `|` on a 2x2 and a 2x3 operand must panic with the column-mismatch message.
    #[test]
    #[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")]
    fn test_bitwise_op_col_mismatch() {
        let data1 = vec![true, false, true, false]; // 2x2
        let data2 = vec![true, true, false, false, true, true]; // 2x3
        let matrix1 = BoolMatrix::from_vec(data1, 2, 2);
        let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
        let _ = &matrix1 | &matrix2; // Should panic
    }

    // Test with String type (requires Clone, PartialEq)
    /// Indexing, mutation, `add_column` and `add_row` on a matrix of `String`s.
    #[test]
    fn test_string_matrix() {
        let data = vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "d".to_string(),
        ];
        let matrix = StringMatrix::from_vec(data.clone(), 2, 2); // 2x2

        assert_eq!(matrix[(0, 0)], "a".to_string());
        assert_eq!(matrix[(1, 0)], "b".to_string());
        assert_eq!(matrix[(0, 1)], "c".to_string());
        assert_eq!(matrix[(1, 1)], "d".to_string());

        // Test modification
        let mut matrix = matrix;
        matrix[(0, 0)] = "hello".to_string();
        assert_eq!(matrix[(0, 0)], "hello".to_string());

        // Test add_column (requires Clone)
        let new_col = vec!["e".to_string(), "f".to_string()];
        matrix.add_column(1, new_col); // Add at index 1
        // Should be:
        // hello e c
        // b     f d
        assert_eq!(matrix.rows(), 2);
        assert_eq!(matrix.cols(), 3);
        assert_eq!(matrix[(0, 0)], "hello".to_string());
        assert_eq!(matrix[(1, 0)], "b".to_string());
        assert_eq!(matrix[(0, 1)], "e".to_string()); // New col
        assert_eq!(matrix[(1, 1)], "f".to_string()); // New col
        assert_eq!(matrix[(0, 2)], "c".to_string()); // Shifted
        assert_eq!(matrix[(1, 2)], "d".to_string()); // Shifted

        // Test add_row (requires Clone)
        let new_row = vec!["g".to_string(), "h".to_string(), "i".to_string()];
        matrix.add_row(0, new_row); // Add at index 0
        // Should be:
        // g     h i
        // hello e c
        // b     f d
        assert_eq!(matrix.rows(), 3);
        assert_eq!(matrix.cols(), 3);
        assert_eq!(matrix[(0, 0)], "g".to_string());
        assert_eq!(matrix[(0, 1)], "h".to_string());
        assert_eq!(matrix[(0, 2)], "i".to_string());
        assert_eq!(matrix[(1, 0)], "hello".to_string()); // Shifted
        assert_eq!(matrix[(2, 2)], "d".to_string()); // Shifted
    }

    /// Element-wise `+`, `-`, `*` and `/` on two 2x2 float matrices.
    #[test]
    fn test_float_matrix_ops() {
        let data1 = vec![1.0, 2.0, 3.0, 4.0]; // 2x2
        let data2 = vec![0.5, 1.5, 2.5, 3.5]; // 2x2
        let matrix1 = FloatMatrix::from_vec(data1, 2, 2);
        let matrix2 = FloatMatrix::from_vec(data2, 2, 2);

        let sum = &matrix1 + &matrix2;
        let diff = &matrix1 - &matrix2;
        let prod = &matrix1 * &matrix2;
        let div = &matrix1 / &matrix2;

        // Check sums (col-major): [1.5, 3.5, 5.5, 7.5]
        assert_eq!(sum.data(), &[1.5, 3.5, 5.5, 7.5]);

        // Check diffs (col-major): [0.5, 0.5, 0.5, 0.5]
        assert_eq!(diff.data(), &[0.5, 0.5, 0.5, 0.5]);

        // Check prods (col-major): [0.5, 3.0, 7.5, 14.0]
        assert_eq!(prod.data(), &[0.5, 3.0, 7.5, 14.0]);

        // Check divs (col-major): [2.0, 1.333..., 1.2, 1.14...]
        // Using element access for more specific checks on floating point results
        assert_eq!(div.rows(), 2);
        assert_eq!(div.cols(), 2);
        assert!((div[(0, 0)] - 1.0 / 0.5).abs() < 1e-9); // 2.0
        assert!((div[(1, 0)] - 2.0 / 1.5).abs() < 1e-9); // 1.333...
        assert!((div[(0, 1)] - 3.0 / 2.5).abs() < 1e-9); // 1.2
        assert!((div[(1, 1)] - 4.0 / 3.5).abs() < 1e-9); // 1.14...
    }

    /// Builds a 3x3 `i32` fixture from the columns `[1, 2, 3]`, `[4, 5, 6]`, `[7, 8, 9]`.
    fn create_test_matrix_i32() -> Matrix<i32> {
        Matrix::from_cols(vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]])
    }

    /// `swap_columns` checked through the `column` accessor, the raw data,
    /// a self-swap, and a swap of adjacent columns.
    #[test]
    fn test_matrix_swap_columns_directly() {
        let mut matrix = create_test_matrix_i32();

        // Store the initial state of the columns we intend to swap AND one that shouldn't change
        let initial_col0_data = matrix.column(0).to_vec(); // Should be [1, 2, 3]
        let initial_col1_data = matrix.column(1).to_vec(); // Should be [4, 5, 6]
        let initial_col2_data = matrix.column(2).to_vec(); // Should be [7, 8, 9]

        // Perform the swap directly on the matrix
        matrix.swap_columns(0, 2); // Swap column 0 and column 2

        // --- Assertions ---

        // 1. Verify the dimensions are unchanged
        assert_eq!(matrix.rows(), 3, "Matrix rows should remain unchanged");
        assert_eq!(matrix.cols(), 3, "Matrix cols should remain unchanged");

        // 2. Verify the column that was NOT swapped is unchanged
        assert_eq!(
            matrix.column(1),
            initial_col1_data.as_slice(), // Comparing slice to slice
            "Column 1 data should be unchanged"
        );

        // 3. Verify the data swap occurred correctly using the COLUMN ACCESSOR
        // The data originally at index 0 should now be at index 2
        assert_eq!(
            matrix.column(2),
            initial_col0_data.as_slice(),
            "Column 2 should now contain the original data from column 0"
        );
        // The data originally at index 2 should now be at index 0
        assert_eq!(
            matrix.column(0),
            initial_col2_data.as_slice(),
            "Column 0 should now contain the original data from column 2"
        );

        // 4. (Optional but useful) Verify the underlying raw data vector
        // Original data: [1, 2, 3, 4, 5, 6, 7, 8, 9]
        // Expected data after swapping col 0 and col 2: [7, 8, 9, 4, 5, 6, 1, 2, 3]
        assert_eq!(
            matrix.data(),
            &[7, 8, 9, 4, 5, 6, 1, 2, 3],
            "Underlying data vector is incorrect after swap"
        );

        // 5. Test swapping with self (should be a no-op)
        let state_before_self_swap = matrix.clone();
        matrix.swap_columns(1, 1);
        assert_eq!(
            matrix, state_before_self_swap,
            "Swapping a column with itself should not change the matrix"
        );

        // 6. Test swapping adjacent columns
        let mut matrix2 = create_test_matrix_i32();
        let initial_col0_data_m2 = matrix2.column(0).to_vec();
        let initial_col1_data_m2 = matrix2.column(1).to_vec();
        matrix2.swap_columns(0, 1);
        assert_eq!(matrix2.column(0), initial_col1_data_m2.as_slice());
        assert_eq!(matrix2.column(1), initial_col0_data_m2.as_slice());
        assert_eq!(matrix2.data(), &[4, 5, 6, 1, 2, 3, 7, 8, 9]);
    }

    // Test broadcastable operations
    /// Scalar and matrix comparisons via `gt_elem`, `eq_elem`, `lt_elem`, `le_elem`, `ge_elem`.
    #[test]
    fn test_comparision_broadcast() {
        let matrix = static_test_matrix();
        // test all > 0
        let result = matrix.gt_elem(0).into_vec();
        let expected = vec![true; result.len()];
        assert_eq!(result, expected);

        let ma = static_test_matrix();
        let mb = static_test_matrix();

        let result = ma.eq_elem(mb);
        assert!(result.all());

        // `i32::MAX` is the value the former `1e10 as i32` saturated to.
        let result = matrix.lt_elem(i32::MAX).all();
        assert!(result);

        // For each element, exactly one cell is both <= and >= its value.
        for i in 0..matrix.rows() {
            for j in 0..matrix.cols() {
                let vx = matrix[(i, j)];
                let c = &(matrix.le_elem(vx)) & &(matrix.ge_elem(vx));
                assert_eq!(c.count(), 1);
            }
        }
    }

    /// Matrix-with-scalar `+`, `*` and `/` apply the scalar to every element.
    #[test]
    fn test_arithmetic_broadcast() {
        let matrix = static_test_matrix();
        let result = &matrix + 1;
        for i in 0..matrix.rows() {
            for j in 0..matrix.cols() {
                assert_eq!(result[(i, j)], matrix[(i, j)] + 1);
            }
        }

        // test mul and div
        let result = &matrix * 2;
        let result2 = &matrix / 2;
        for i in 0..matrix.rows() {
            for j in 0..matrix.cols() {
                assert_eq!(result[(i, j)], matrix[(i, j)] * 2);
                assert_eq!(result2[(i, j)], matrix[(i, j)] / 2);
            }
        }
    }
}