Compare commits

...

39 Commits

Author SHA1 Message Date
Palash Tyagi
e9cda2b8c7
Merge 58b38311b50ec33d1dfe96b0c91e407be2917372 into f44bb5b2053c6303a466c966dde57f3d130b67f9 2025-07-06 01:04:20 +01:00
58b38311b5
Merge branch 'main' into dataframe 2025-07-06 01:04:19 +01:00
f44bb5b205
Merge pull request #54 from Magnus167/fix-benchmark
Update benchmarks and fix README formatting
2025-07-06 01:00:48 +01:00
Palash Tyagi
fde9c73a66 Fix README formatting by commenting out logo image line 2025-07-06 00:53:31 +01:00
Palash Tyagi
85e0eb7e67 Update benchmarks to use Axis enum for apply_axis method 2025-07-06 00:53:27 +01:00
4ed23069fc
Merge branch 'main' into dataframe 2025-07-06 00:47:15 +01:00
b8e64811ed
Merge pull request #53 from Magnus167/matrix-ops-enhancement
Updates to matrix operations for SeriesOps trait
2025-07-06 00:46:09 +01:00
Palash Tyagi
2ac2db258f Add tests for matrix multiplication and transpose operations in Frame 2025-07-06 00:38:22 +01:00
Palash Tyagi
7a68d13eb3 Change transpose method to take an immutable reference 2025-07-06 00:38:14 +01:00
Palash Tyagi
f39c678192 Add test for matrix multiplication error handling and create 4x4 test matrix with NaNs 2025-07-06 00:38:07 +01:00
Palash Tyagi
ad9f89860e Add benchmarks for matrix and frame operations including multiplication, summation, product, application along axes, and transposition 2025-07-06 00:23:03 +01:00
Palash Tyagi
ef574892fa Add transpose method to Frame for matrix transposition 2025-07-06 00:22:41 +01:00
Palash Tyagi
30bff6ecf4 Add transpose method to Frame for matrix transposition 2025-07-06 00:22:10 +01:00
Palash Tyagi
9daf583a4d Implement matrix multiplication and dot product methods for SeriesOps in Frame<f64> 2025-07-05 21:07:46 +01:00
Palash Tyagi
1ebc3546d2 Refactor assertions for clarity in Matrix methods and update test cases for consistency 2025-07-05 21:07:32 +01:00
Palash Tyagi
ffe635f1c4 Add transpose method and corresponding tests for Matrix 2025-07-05 21:07:21 +01:00
Palash Tyagi
75f194b8c9 Implement matrix multiplication and dot product methods for FloatMatrix in SeriesOps trait 2025-07-05 20:44:16 +01:00
Palash Tyagi
7d7794627b Refactor DataFrame usage example in README.md for clarity and consistency 2025-07-04 20:15:47 +01:00
d9bdf8ee96
Merge branch 'main' into dataframe 2025-07-04 00:59:57 +01:00
a61ff8a4e1
Merge branch 'main' into dataframe 2025-07-04 00:55:16 +01:00
Palash Tyagi
26ee580710 Refactor README: update DataFrame usage example 2025-07-04 00:46:12 +01:00
Palash Tyagi
96934cd89f update DataFrame module exports 2025-07-04 00:45:45 +01:00
Palash Tyagi
27ab1ac129 reimplement dataframe functionality from scratch 2025-07-04 00:45:28 +01:00
Palash Tyagi
eb4fefe363 Enhance DataFrame display: implement column ellipsis for large datasets; improve row and column index calculations for better output formatting. 2025-07-02 23:45:43 +01:00
Palash Tyagi
60cc97e702 Enhance DataFrame display: implement row truncation with ellipsis for large datasets; improve column width calculations and formatting for better readability. 2025-07-02 23:33:34 +01:00
Palash Tyagi
7e2a5ec18d Enhance DataFrame display: update head and tail methods for improved row retrieval and formatting; refine display output for empty DataFrames and adjust column width calculations. 2025-07-02 22:18:09 +01:00
Palash Tyagi
4038d25b07 applied formatting 2025-07-02 00:25:45 +01:00
Palash Tyagi
aa15248b58 Rename variable for clarity in DataFrame display formatting 2025-07-02 00:25:31 +01:00
Palash Tyagi
fa392ec631 Add head_n and tail_n methods to DataFrame for row retrieval; enhance display formatting 2025-07-02 00:22:52 +01:00
Palash Tyagi
8b6f16236a Refactor TypedFrame methods using macros for common functionality and improve column accessors 2025-07-01 23:26:57 +01:00
Palash Tyagi
58acea8467 Add DataFrame usage examples to README.md 2025-06-22 21:16:06 +01:00
Palash Tyagi
2607d9c3b0 Add pub use statement for DataFrame, DataFrameColumn, and TypedFrame in mod.rs 2025-06-22 21:15:12 +01:00
Palash Tyagi
57ed06f79b Reimplemented dataframe class with TypedFrame interface 2025-06-22 19:47:12 +01:00
Palash Tyagi
01a132264f Remove unused imports and clean up test module in DataFrame implementation 2025-06-22 05:44:24 +01:00
Palash Tyagi
ff4535c56b Implement column renaming in DataFrame, updating both logical names and underlying Frame references. 2025-06-22 05:35:48 +01:00
9b480e8130
Merge branch 'main' into dataframe 2025-06-22 05:22:06 +01:00
Palash Tyagi
fe666a4ddb First draft: Implement DataFrame and DataFrameColumn structures 2025-06-22 05:01:19 +01:00
Palash Tyagi
b80d5ab381 Add documentation for the DataFrame module and include it in the library 2025-06-22 05:00:59 +01:00
Palash Tyagi
49f7558225 Enhance column access methods to clarify usage by name and physical index 2025-06-22 05:00:42 +01:00
9 changed files with 1125 additions and 240 deletions

130
README.md
View File

@ -1,5 +1,6 @@
# rustframe
# <img align="center" alt="Rustframe" src=".github/rustframe_logo.png" height="50px" /> rustframe <!-- # <img align="center" alt="Rustframe" src=".github/rustframe_logo.png" height="50px" /> rustframe -->
<!-- though the centre tag doesn't work as it would normally, it achieves the desired effect --> <!-- though the centre tag doesn't work as it would normally, it achieves the desired effect -->
@ -100,6 +101,133 @@ assert!(check);
``` ```
---
## DataFrame Usage Example
```rust
use chrono::NaiveDate;
use rustframe::dataframe::DataFrame;
use rustframe::utils::{BDateFreq, BDatesList};
use std::any::TypeId;
use std::collections::HashMap;
// Helper for NaiveDate
fn d(y: i32, m: u32, d: u32) -> NaiveDate {
NaiveDate::from_ymd_opt(y, m, d).unwrap()
}
// Create a new DataFrame
let mut df = DataFrame::new();
// Add columns of different types
df.add_column("col_int1", vec![1, 2, 3, 4, 5]);
df.add_column("col_float1", vec![1.1, 2.2, 3.3, 4.4, 5.5]);
df.add_column(
"col_string",
vec![
"apple".to_string(),
"banana".to_string(),
"cherry".to_string(),
"date".to_string(),
"elderberry".to_string(),
],
);
df.add_column("col_bool", vec![true, false, true, false, true]);
// df.add_column("col_date", vec![d(2023,1,1), d(2023,1,2), d(2023,1,3), d(2023,1,4), d(2023,1,5)]);
df.add_column(
"col_date",
BDatesList::from_n_periods("2023-01-01".to_string(), BDateFreq::Daily, 5)
.unwrap()
.list()
.unwrap(),
);
println!("DataFrame after initial column additions:\n{}", df);
// Demonstrate frame re-use when adding columns of existing types
let initial_frames_count = df.num_internal_frames();
println!(
"\nInitial number of internal frames: {}",
initial_frames_count
);
df.add_column("col_int2", vec![6, 7, 8, 9, 10]);
df.add_column("col_float2", vec![6.6, 7.7, 8.8, 9.9, 10.0]);
let frames_after_reuse = df.num_internal_frames();
println!(
"Number of internal frames after adding more columns of existing types: {}",
frames_after_reuse
);
assert_eq!(initial_frames_count, frames_after_reuse); // Should be equal, demonstrating re-use
println!(
"\nDataFrame after adding more columns of existing types:\n{}",
df
);
// Get number of rows and columns
println!("Rows: {}", df.rows()); // Output: Rows: 5
println!("Columns: {}", df.cols()); // Output: Columns: 7
// Get column names
println!("Column names: {:?}", df.get_column_names());
// Output: Column names: ["col_int1", "col_float1", "col_string", "col_bool", "col_date", "col_int2", "col_float2"]
// Get a specific column by name and type
let int_col = df.get_column::<i32>("col_int1").unwrap();
// Output: Integer column (col_int1): [1, 2, 3, 4, 5]
println!("Integer column (col_int1): {:?}", int_col);
let int_col2 = df.get_column::<i32>("col_int2").unwrap();
// Output: Integer column (col_int2): [6, 7, 8, 9, 10]
println!("Integer column (col_int2): {:?}", int_col2);
let float_col = df.get_column::<f64>("col_float1").unwrap();
// Output: Float column (col_float1): [1.1, 2.2, 3.3, 4.4, 5.5]
println!("Float column (col_float1): {:?}", float_col);
// Attempt to get a column with incorrect type (returns None)
let wrong_type_col = df.get_column::<bool>("col_int1");
// Output: Wrong type column: None
println!("Wrong type column: {:?}", wrong_type_col);
// Get a row by index
let row_0 = df.get_row(0).unwrap();
println!("Row 0: {:?}", row_0);
// Output: Row 0: {"col_int1": "1", "col_float1": "1.1", "col_string": "apple", "col_bool": "true", "col_date": "2023-01-01", "col_int2": "6", "col_float2": "6.6"}
let row_2 = df.get_row(2).unwrap();
println!("Row 2: {:?}", row_2);
// Output: Row 2: {"col_int1": "3", "col_float1": "3.3", "col_string": "cherry", "col_bool": "true", "col_date": "2023-01-03", "col_int2": "8", "col_float2": "8.8"}
// Attempt to get an out-of-bounds row (returns None)
let row_out_of_bounds = df.get_row(10);
// Output: Row out of bounds: None
println!("Row out of bounds: {:?}", row_out_of_bounds);
// Drop a column
df.drop_column("col_bool");
println!("\nDataFrame after dropping 'col_bool':\n{}", df);
println!("Columns after drop: {}", df.cols());
println!("Column names after drop: {:?}", df.get_column_names());
// Drop another column, ensuring the underlying Frame is removed if empty
df.drop_column("col_float1");
println!("\nDataFrame after dropping 'col_float1':\n{}", df);
println!("Columns after second drop: {}", df.cols());
println!(
"Column names after second drop: {:?}",
df.get_column_names()
);
// Attempt to drop a non-existent column (will panic)
// df.drop_column("non_existent_col"); // Uncomment to see panic
```
### More examples ### More examples
See the [examples](./examples/) directory for some demonstrations of Rustframe's syntax and functionality. See the [examples](./examples/) directory for some demonstrations of Rustframe's syntax and functionality.

View File

@ -4,7 +4,7 @@ use criterion::{criterion_group, criterion_main, Criterion};
use rustframe::{ use rustframe::{
frame::{Frame, RowIndex}, frame::{Frame, RowIndex},
matrix::{BoolMatrix, Matrix, SeriesOps}, matrix::{Axis, BoolMatrix, Matrix, SeriesOps},
utils::{DateFreq, DatesList}, utils::{DateFreq, DatesList},
}; };
use std::time::Duration; use std::time::Duration;
@ -109,6 +109,47 @@ fn matrix_operations_benchmark(c: &mut Criterion, sizes: &[usize]) {
let _result = &ma / 2.0; let _result = &ma / 2.0;
}); });
}); });
c.bench_function(
&format!("matrix matrix_multiply ({}x{})", size, size),
|b| {
b.iter(|| {
let _result = ma.matrix_mul(&ma);
});
},
);
c.bench_function(&format!("matrix sum_horizontal ({}x{})", size, size), |b| {
b.iter(|| {
let _result = ma.sum_horizontal();
});
});
c.bench_function(&format!("matrix sum_vertical ({}x{})", size, size), |b| {
b.iter(|| {
let _result = ma.sum_vertical();
});
});
c.bench_function(
&format!("matrix prod_horizontal ({}x{})", size, size),
|b| {
b.iter(|| {
let _result = ma.prod_horizontal();
});
},
);
c.bench_function(&format!("matrix prod_vertical ({}x{})", size, size), |b| {
b.iter(|| {
let _result = ma.prod_vertical();
});
});
c.bench_function(&format!("matrix apply_axis ({}x{})", size, size), |b| {
b.iter(|| {
let _result = ma.apply_axis(Axis::Col, |col| col.iter().sum::<f64>());
});
});
c.bench_function(&format!("matrix transpose ({}x{})", size, size), |b| {
b.iter(|| {
let _result = ma.transpose();
});
});
} }
for &size in sizes { for &size in sizes {
@ -187,6 +228,12 @@ fn benchmark_frame_operations(c: &mut Criterion, sizes: &[usize]) {
}); });
}); });
c.bench_function(&format!("frame matrix_multiply ({}x{})", size, size), |b| {
b.iter(|| {
let _result = fa.matrix_mul(&fb);
});
});
c.bench_function(&format!("frame sum_horizontal ({}x{})", size, size), |b| { c.bench_function(&format!("frame sum_horizontal ({}x{})", size, size), |b| {
b.iter(|| { b.iter(|| {
let _result = fa.sum_horizontal(); let _result = fa.sum_horizontal();
@ -207,6 +254,16 @@ fn benchmark_frame_operations(c: &mut Criterion, sizes: &[usize]) {
let _result = fa.prod_vertical(); let _result = fa.prod_vertical();
}); });
}); });
c.bench_function(&format!("frame apply_axis ({}x{})", size, size), |b| {
b.iter(|| {
let _result = fa.apply_axis(Axis::Col, |col| col.iter().sum::<f64>());
});
});
c.bench_function(&format!("frame transpose ({}x{})", size, size), |b| {
b.iter(|| {
let _result = fa.transpose();
});
});
} }
} }

659
src/dataframe/df.rs Normal file
View File

@ -0,0 +1,659 @@
use crate::frame::{Frame, RowIndex};
use std::any::{Any, TypeId}; // `TypeId` keys the per-element-type frame storage
use std::collections::HashMap;
use std::fmt;
/// Maximum number of data rows the `Display` impl writes before emitting a `...` line.
const DEFAULT_DISPLAY_ROWS: usize = 5;
/// Maximum number of columns the `Display` impl writes before appending `...` to a line.
const DEFAULT_DISPLAY_COLS: usize = 10;
// Trait to enable type-agnostic operations on Frame objects within DataFrame
/// Object-safe view of a `Frame<T>` that hides the element type `T`, so frames with
/// different element types can be stored together as `Box<dyn SubFrame>`.
pub trait SubFrame: Send + Sync + fmt::Debug + Any {
    /// Returns the number of rows in the frame.
    fn rows(&self) -> usize;
    /// Returns the value stored at row `physical_row_idx` of column `col_name`,
    /// rendered as a `String`.
    fn get_value_as_string(&self, physical_row_idx: usize, col_name: &str) -> String;
    /// Returns a boxed copy of the frame behind the trait object.
    fn clone_box(&self) -> Box<dyn SubFrame>;
    /// Removes the column named `col_name` from the frame.
    fn delete_column_from_frame(&mut self, col_name: &str);
    /// Returns the number of columns in the underlying frame.
    fn get_frame_cols(&self) -> usize;
    // Methods for downcasting to concrete types
    /// Exposes the frame as `&dyn Any` so callers can `downcast_ref` to `Frame<T>`.
    fn as_any(&self) -> &dyn Any;
    /// Exposes the frame as `&mut dyn Any` so callers can `downcast_mut` to `Frame<T>`.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
// Implement SubFrame for any Frame<T> that meets the requirements
impl<T> SubFrame for Frame<T>
where
T: Clone + PartialEq + fmt::Display + fmt::Debug + 'static + Send + Sync + Any,
{
fn rows(&self) -> usize {
self.rows()
}
fn get_value_as_string(&self, physical_row_idx: usize, col_name: &str) -> String {
self.get_row(physical_row_idx).get(col_name).to_string()
}
fn clone_box(&self) -> Box<dyn SubFrame> {
Box::new(self.clone())
}
fn delete_column_from_frame(&mut self, col_name: &str) {
self.delete_column(col_name);
}
fn get_frame_cols(&self) -> usize {
self.cols()
}
fn as_any(&self) -> &dyn Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
}
/// Table of named columns whose element types may differ from column to column.
///
/// Columns that share an element type `T` live together in a single `Frame<T>`, which is
/// boxed as a `SubFrame` trait object and keyed by `TypeId::of::<T>()`.
pub struct DataFrame {
    frames_by_type: HashMap<TypeId, Box<dyn SubFrame>>, // Maps TypeId to the Frame holding columns of that type
    column_to_type: HashMap<String, TypeId>, // Maps column name to its TypeId
    // Column names in the order the columns were added.
    column_names: Vec<String>,
    // Row index of the DataFrame; its length is what `rows()` reports.
    index: RowIndex,
}
impl DataFrame {
    /// Creates an empty DataFrame: no columns, no internal frames and an empty
    /// `0..0` range index.
    pub fn new() -> Self {
        DataFrame {
            frames_by_type: HashMap::new(),
            column_to_type: HashMap::new(),
            column_names: Vec::new(),
            index: RowIndex::Range(0..0), // Initialize with an empty range index
        }
    }

    /// Returns the number of rows in the DataFrame.
    pub fn rows(&self) -> usize {
        self.index.len()
    }

    /// Returns the number of columns in the DataFrame.
    pub fn cols(&self) -> usize {
        self.column_names.len()
    }

    /// Returns a reference to the vector of column names, in insertion order.
    pub fn get_column_names(&self) -> &Vec<String> {
        &self.column_names
    }

    /// Returns the number of internal Frame objects (one per unique data type).
    pub fn num_internal_frames(&self) -> usize {
        self.frames_by_type.len()
    }

    /// Returns the data of column `col_name` as a slice of `T`.
    ///
    /// Returns `None` when the column does not exist or when it was stored with an
    /// element type other than `T`.
    pub fn get_column<T>(&self, col_name: &str) -> Option<&[T]>
    where
        T: Clone + PartialEq + fmt::Display + fmt::Debug + 'static + Send + Sync + Any,
    {
        let stored_type = self.column_to_type.get(col_name)?;
        if *stored_type != TypeId::of::<T>() {
            return None;
        }
        let frame = self
            .frames_by_type
            .get(stored_type)?
            .as_any()
            .downcast_ref::<Frame<T>>()?;
        Some(frame.column(col_name))
    }

    /// Returns a HashMap representing a row, mapping column names to their string values.
    ///
    /// Returns `None` when `row_idx` is not smaller than `rows()`.
    pub fn get_row(&self, row_idx: usize) -> Option<HashMap<String, String>> {
        if row_idx >= self.rows() {
            return None;
        }
        let mut row_data = HashMap::with_capacity(self.column_names.len());
        for col_name in &self.column_names {
            // Columns without a backing frame are skipped rather than reported.
            let sub_frame = self
                .column_to_type
                .get(col_name)
                .and_then(|type_id| self.frames_by_type.get(type_id));
            if let Some(sub_frame) = sub_frame {
                let value = sub_frame.get_value_as_string(row_idx, col_name);
                row_data.insert(col_name.clone(), value);
            }
        }
        Some(row_data)
    }

    /// Appends a column named `col_name` holding `data`.
    ///
    /// The column is stored in the internal `Frame<T>` for its element type; a new frame
    /// is created only when no column of type `T` exists yet. The first column added to
    /// an empty DataFrame defines the row count.
    ///
    /// # Panics
    ///
    /// Panics if a column with the same name already exists, or if `data.len()` differs
    /// from the number of rows of the existing columns.
    pub fn add_column<T>(&mut self, col_name: &str, data: Vec<T>)
    where
        T: Clone + PartialEq + fmt::Display + fmt::Debug + 'static + Send + Sync + Any,
    {
        let type_id = TypeId::of::<T>();

        // Column names must be unique across the entire DataFrame, not just per type.
        if self.column_to_type.contains_key(col_name) {
            panic!(
                "DataFrame::add_column: duplicate column name: '{}'",
                col_name
            );
        }

        if self.column_names.is_empty() {
            // The first column defines the DataFrame's index.
            self.index = RowIndex::Range(0..data.len());
        } else if data.len() != self.index.len() {
            panic!(
                "DataFrame::add_column: new column '{}' has {} rows, but existing columns have {} rows",
                col_name,
                data.len(),
                self.index.len()
            );
        }

        if let Some(sub_frame_box) = self.frames_by_type.get_mut(&type_id) {
            // A frame for this element type already exists: reuse it.
            let frame = sub_frame_box
                .as_any_mut()
                .downcast_mut::<Frame<T>>()
                .unwrap_or_else(|| {
                    // Frames are keyed by their own element TypeId, so this is an
                    // internal invariant violation rather than a user error.
                    panic!(
                        "Type mismatch when downcasting existing SubFrame for TypeId {:?}",
                        type_id
                    )
                });
            frame.add_column(col_name.to_string(), data);
        } else {
            // No frame of this type yet: build one from a single-column matrix that
            // shares the DataFrame's index.
            let new_frame = Frame::new(
                crate::matrix::Matrix::from_cols(vec![data]),
                vec![col_name.to_string()],
                Some(self.index.clone()),
            );
            self.frames_by_type.insert(type_id, Box::new(new_frame));
        }

        // Update column mappings and names
        self.column_to_type.insert(col_name.to_string(), type_id);
        self.column_names.push(col_name.to_string());
    }

    /// Drops a column from the DataFrame.
    ///
    /// The internal frame for the column's type is removed once it holds no more
    /// columns. When the last column is dropped, the row index is reset to the empty
    /// `0..0` range so the DataFrame is back in the same state as `DataFrame::new()`.
    ///
    /// # Panics
    ///
    /// Panics if the column does not exist.
    pub fn drop_column(&mut self, col_name: &str) {
        // 1. Get the TypeId associated with the column
        let type_id = self.column_to_type.remove(col_name).unwrap_or_else(|| {
            panic!("DataFrame::drop_column: column '{}' not found", col_name)
        });

        // 2. Remove the column name from the ordered list
        self.column_names.retain(|name| name.as_str() != col_name);

        // 3. Find the Frame object and delete the column from it
        let sub_frame_box = self.frames_by_type.get_mut(&type_id).unwrap_or_else(|| {
            // This should not happen if column_to_type was consistent
            panic!(
                "DataFrame::drop_column: internal error, no frame found for type_id {:?}",
                type_id
            )
        });
        sub_frame_box.delete_column_from_frame(col_name);

        // 4. If the Frame object for this type becomes empty, remove it from frames_by_type
        if sub_frame_box.get_frame_cols() == 0 {
            self.frames_by_type.remove(&type_id);
        }

        // 5. With no columns left there are no rows either; without this reset `rows()`
        //    would keep reporting the old length and `Display` would print blank rows.
        if self.column_names.is_empty() {
            self.index = RowIndex::Range(0..0);
        }
    }
}

impl Default for DataFrame {
    /// Equivalent to [`DataFrame::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl fmt::Display for DataFrame {
    /// Writes a header line followed by at most `DEFAULT_DISPLAY_ROWS` data rows.
    ///
    /// Every cell is left-aligned and padded to 15 characters. At most
    /// `DEFAULT_DISPLAY_COLS` columns are written per line; a trailing `...` marks
    /// omitted columns and a final `...` line marks omitted rows.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let cols_truncated = self.column_names.len() > DEFAULT_DISPLAY_COLS;
        let visible_cols = || self.column_names.iter().take(DEFAULT_DISPLAY_COLS);

        // Header line.
        for name in visible_cols() {
            write!(f, "{:<15}", name)?;
        }
        if cols_truncated {
            write!(f, "...")?;
        }
        writeln!(f)?;

        // Data lines.
        let total_rows = self.index.len();
        for row in 0..total_rows.min(DEFAULT_DISPLAY_ROWS) {
            for name in visible_cols() {
                // A missing type mapping or a missing frame both mean the internal
                // maps are inconsistent; either case is rendered as "[ERROR]".
                let cell = self
                    .column_to_type
                    .get(name)
                    .and_then(|type_id| self.frames_by_type.get(type_id))
                    .map(|sub_frame| sub_frame.get_value_as_string(row, name));
                match cell {
                    Some(text) => write!(f, "{:<15}", text)?,
                    None => write!(f, "{:<15}", "[ERROR]")?,
                }
            }
            if cols_truncated {
                write!(f, "...")?;
            }
            writeln!(f)?;
        }

        // Ellipsis line when rows were left out.
        if total_rows > DEFAULT_DISPLAY_ROWS {
            writeln!(f, "...")?;
        }
        Ok(())
    }
}
impl fmt::Debug for DataFrame {
    /// Formats the DataFrame as a struct listing its column names, index, name-to-type
    /// map and per-type frames, in that order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("DataFrame");
        builder.field("column_names", &self.column_names);
        builder.field("index", &self.index);
        builder.field("column_to_type", &self.column_to_type);
        builder.field("frames_by_type", &self.frames_by_type);
        builder.finish()
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    // Unit tests for `DataFrame`: construction, adding/dropping columns, typed column
    // access, row access and the `Display`/`Debug` output.
    use super::*;
    use crate::frame::Frame;
    // NOTE(review): `Matrix` does not appear to be referenced anywhere in this module.
    use crate::matrix::Matrix;
    // A freshly constructed DataFrame has no rows, columns or internal state.
    #[test]
    fn test_dataframe_new() {
        let df = DataFrame::new();
        assert_eq!(df.rows(), 0);
        assert_eq!(df.cols(), 0);
        assert!(df.get_column_names().is_empty());
        assert!(df.frames_by_type.is_empty());
        assert!(df.column_to_type.is_empty());
    }
    // The first column defines the row count and creates the frame for its type.
    #[test]
    fn test_dataframe_add_column_initial() {
        let mut df = DataFrame::new();
        let data = vec![1, 2, 3];
        df.add_column("col_int", data.clone());
        assert_eq!(df.rows(), 3);
        assert_eq!(df.cols(), 1);
        assert_eq!(df.get_column_names(), &vec!["col_int".to_string()]);
        assert!(df.frames_by_type.contains_key(&TypeId::of::<i32>()));
        assert_eq!(df.column_to_type.get("col_int"), Some(&TypeId::of::<i32>()));
        // Verify the underlying frame
        let sub_frame_box = df.frames_by_type.get(&TypeId::of::<i32>()).unwrap();
        let frame = sub_frame_box.as_any().downcast_ref::<Frame<i32>>().unwrap();
        assert_eq!(frame.rows(), 3);
        assert_eq!(frame.cols(), 1);
        assert_eq!(frame.columns(), &vec!["col_int".to_string()]);
    }
    // Two columns of the same type end up in the same underlying frame.
    #[test]
    fn test_dataframe_add_column_same_type() {
        let mut df = DataFrame::new();
        df.add_column("col_int1", vec![1, 2, 3]);
        df.add_column("col_int2", vec![4, 5, 6]);
        assert_eq!(df.rows(), 3);
        assert_eq!(df.cols(), 2);
        assert_eq!(
            df.get_column_names(),
            &vec!["col_int1".to_string(), "col_int2".to_string()]
        );
        assert!(df.frames_by_type.contains_key(&TypeId::of::<i32>()));
        assert_eq!(
            df.column_to_type.get("col_int1"),
            Some(&TypeId::of::<i32>())
        );
        assert_eq!(
            df.column_to_type.get("col_int2"),
            Some(&TypeId::of::<i32>())
        );
        // Verify the underlying frame
        let sub_frame_box = df.frames_by_type.get(&TypeId::of::<i32>()).unwrap();
        let frame = sub_frame_box.as_any().downcast_ref::<Frame<i32>>().unwrap();
        assert_eq!(frame.rows(), 3);
        assert_eq!(frame.cols(), 2);
        assert_eq!(
            frame.columns(),
            &vec!["col_int1".to_string(), "col_int2".to_string()]
        );
    }
    // Columns of different types each get their own underlying frame.
    #[test]
    fn test_dataframe_add_column_different_type() {
        let mut df = DataFrame::new();
        df.add_column("col_int", vec![1, 2, 3]);
        df.add_column("col_float", vec![1.1, 2.2, 3.3]);
        df.add_column(
            "col_string",
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
        );
        assert_eq!(df.rows(), 3);
        assert_eq!(df.cols(), 3);
        assert_eq!(
            df.get_column_names(),
            &vec![
                "col_int".to_string(),
                "col_float".to_string(),
                "col_string".to_string()
            ]
        );
        assert!(df.frames_by_type.contains_key(&TypeId::of::<i32>()));
        assert!(df.frames_by_type.contains_key(&TypeId::of::<f64>()));
        assert!(df.frames_by_type.contains_key(&TypeId::of::<String>()));
        assert_eq!(df.column_to_type.get("col_int"), Some(&TypeId::of::<i32>()));
        assert_eq!(
            df.column_to_type.get("col_float"),
            Some(&TypeId::of::<f64>())
        );
        assert_eq!(
            df.column_to_type.get("col_string"),
            Some(&TypeId::of::<String>())
        );
        // Verify underlying frames
        let int_frame = df
            .frames_by_type
            .get(&TypeId::of::<i32>())
            .unwrap()
            .as_any()
            .downcast_ref::<Frame<i32>>()
            .unwrap();
        assert_eq!(int_frame.columns(), &vec!["col_int".to_string()]);
        let float_frame = df
            .frames_by_type
            .get(&TypeId::of::<f64>())
            .unwrap()
            .as_any()
            .downcast_ref::<Frame<f64>>()
            .unwrap();
        assert_eq!(float_frame.columns(), &vec!["col_float".to_string()]);
        let string_frame = df
            .frames_by_type
            .get(&TypeId::of::<String>())
            .unwrap()
            .as_any()
            .downcast_ref::<Frame<String>>()
            .unwrap();
        assert_eq!(string_frame.columns(), &vec!["col_string".to_string()]);
    }
    // `get_column` returns the data for a matching name and type, and `None` otherwise.
    #[test]
    fn test_dataframe_get_column() {
        let mut df = DataFrame::new();
        df.add_column("col_int", vec![1, 2, 3]);
        df.add_column("col_float", vec![1.1, 2.2, 3.3]);
        df.add_column(
            "col_string",
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
        );
        // Test getting existing columns with correct type
        assert_eq!(
            df.get_column::<i32>("col_int").unwrap(),
            vec![1, 2, 3].as_slice()
        );
        assert_eq!(
            df.get_column::<f64>("col_float").unwrap(),
            vec![1.1, 2.2, 3.3].as_slice()
        );
        assert_eq!(
            df.get_column::<String>("col_string").unwrap(),
            vec!["a".to_string(), "b".to_string(), "c".to_string()].as_slice()
        );
        // Test getting non-existent column
        assert_eq!(df.get_column::<i32>("non_existent"), None);
        // Test getting existing column with incorrect type
        assert_eq!(df.get_column::<f64>("col_int"), None);
        assert_eq!(df.get_column::<i32>("col_float"), None);
    }
    // `get_row` stringifies every column of a row and returns `None` past the last row.
    #[test]
    fn test_dataframe_get_row() {
        let mut df = DataFrame::new();
        df.add_column("col_int", vec![1, 2, 3]);
        df.add_column("col_float", vec![1.1, 2.2, 3.3]);
        df.add_column(
            "col_string",
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
        );
        // Test getting an existing row
        let row0 = df.get_row(0).unwrap();
        assert_eq!(row0.get("col_int"), Some(&"1".to_string()));
        assert_eq!(row0.get("col_float"), Some(&"1.1".to_string()));
        assert_eq!(row0.get("col_string"), Some(&"a".to_string()));
        let row1 = df.get_row(1).unwrap();
        assert_eq!(row1.get("col_int"), Some(&"2".to_string()));
        assert_eq!(row1.get("col_float"), Some(&"2.2".to_string()));
        assert_eq!(row1.get("col_string"), Some(&"b".to_string()));
        // Test getting an out-of-bounds row
        assert_eq!(df.get_row(3), None);
    }
    // Adding a second column with an existing name panics.
    #[test]
    #[should_panic(expected = "DataFrame::add_column: duplicate column name: 'col_int'")]
    fn test_dataframe_add_column_duplicate_name() {
        let mut df = DataFrame::new();
        df.add_column("col_int", vec![1, 2, 3]);
        df.add_column("col_int", vec![4, 5, 6]);
    }
    // Adding a column whose length differs from the existing row count panics.
    #[test]
    #[should_panic(
        expected = "DataFrame::add_column: new column 'col_int2' has 2 rows, but existing columns have 3 rows"
    )]
    fn test_dataframe_add_column_mismatched_rows() {
        let mut df = DataFrame::new();
        df.add_column("col_int1", vec![1, 2, 3]);
        df.add_column("col_int2", vec![4, 5]);
    }
    // Six rows are added, so `Display` should print the first five followed by `...`.
    // NOTE(review): `Display` pads every cell with `{:<15}`; confirm the whitespace inside
    // the expected literal below matches that padding exactly.
    #[test]
    fn test_dataframe_display() {
        let mut df = DataFrame::new();
        df.add_column("col_int", vec![1, 2, 3, 4, 5, 6]);
        df.add_column("col_float", vec![1.1, 2.2, 3.3, 4.4, 5.5, 6.6]);
        df.add_column(
            "col_string",
            vec![
                "a".to_string(),
                "b".to_string(),
                "c".to_string(),
                "d".to_string(),
                "e".to_string(),
                "f".to_string(),
            ],
        );
        let expected_output = "\
col_int col_float col_string
1 1.1 a
2 2.2 b
3 3.3 c
4 4.4 d
5 5.5 e
...
";
        assert_eq!(format!("{}", df), expected_output);
    }
    // `Debug` output names the struct and each of its four fields.
    #[test]
    fn test_dataframe_debug() {
        let mut df = DataFrame::new();
        df.add_column("col_int", vec![1, 2, 3]);
        df.add_column("col_float", vec![1.1, 2.2, 3.3]);
        let debug_output = format!("{:?}", df);
        assert!(debug_output.contains("DataFrame {"));
        assert!(debug_output.contains("column_names: [\"col_int\", \"col_float\"]"));
        assert!(debug_output.contains("index: Range(0..3)"));
        assert!(debug_output.contains("column_to_type: {"));
        assert!(debug_output.contains("frames_by_type: {"));
    }
    // Dropping columns of one type keeps that type's frame until its last column is gone.
    #[test]
    fn test_dataframe_drop_column_single_type() {
        let mut df = DataFrame::new();
        df.add_column("col_int1", vec![1, 2, 3]);
        df.add_column("col_int2", vec![4, 5, 6]);
        df.add_column("col_float", vec![1.1, 2.2, 3.3]);
        assert_eq!(df.cols(), 3);
        assert_eq!(
            df.get_column_names(),
            &vec![
                "col_int1".to_string(),
                "col_int2".to_string(),
                "col_float".to_string()
            ]
        );
        assert!(df.frames_by_type.contains_key(&TypeId::of::<i32>()));
        assert!(df.frames_by_type.contains_key(&TypeId::of::<f64>()));
        df.drop_column("col_int1");
        assert_eq!(df.cols(), 2);
        assert_eq!(
            df.get_column_names(),
            &vec!["col_int2".to_string(), "col_float".to_string()]
        );
        assert!(df.column_to_type.get("col_int1").is_none());
        assert!(df.frames_by_type.contains_key(&TypeId::of::<i32>())); // Frame<i32> should still exist
        let int_frame = df
            .frames_by_type
            .get(&TypeId::of::<i32>())
            .unwrap()
            .as_any()
            .downcast_ref::<Frame<i32>>()
            .unwrap();
        assert_eq!(int_frame.columns(), &vec!["col_int2".to_string()]);
        df.drop_column("col_int2");
        assert_eq!(df.cols(), 1);
        assert_eq!(df.get_column_names(), &vec!["col_float".to_string()]);
        assert!(df.column_to_type.get("col_int2").is_none());
        assert!(!df.frames_by_type.contains_key(&TypeId::of::<i32>())); // Frame<i32> should be removed
        assert!(df.frames_by_type.contains_key(&TypeId::of::<f64>()));
    }
    // Dropping the only column of a type removes that type's frame; dropping every
    // column leaves all internal maps empty.
    #[test]
    fn test_dataframe_drop_column_mixed_types() {
        let mut df = DataFrame::new();
        df.add_column("col_int", vec![1, 2, 3]);
        df.add_column("col_float", vec![1.1, 2.2, 3.3]);
        df.add_column(
            "col_string",
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
        );
        assert_eq!(df.cols(), 3);
        assert!(df.frames_by_type.contains_key(&TypeId::of::<i32>()));
        assert!(df.frames_by_type.contains_key(&TypeId::of::<f64>()));
        assert!(df.frames_by_type.contains_key(&TypeId::of::<String>()));
        df.drop_column("col_float");
        assert_eq!(df.cols(), 2);
        assert_eq!(
            df.get_column_names(),
            &vec!["col_int".to_string(), "col_string".to_string()]
        );
        assert!(df.column_to_type.get("col_float").is_none());
        assert!(!df.frames_by_type.contains_key(&TypeId::of::<f64>())); // Frame<f64> should be removed
        assert!(df.frames_by_type.contains_key(&TypeId::of::<i32>()));
        assert!(df.frames_by_type.contains_key(&TypeId::of::<String>()));
        df.drop_column("col_int");
        df.drop_column("col_string");
        assert_eq!(df.cols(), 0);
        assert!(df.get_column_names().is_empty());
        assert!(df.frames_by_type.is_empty());
        assert!(df.column_to_type.is_empty());
    }
    // Dropping an unknown column panics with a "not found" message.
    #[test]
    #[should_panic(expected = "DataFrame::drop_column: column 'non_existent' not found")]
    fn test_dataframe_drop_column_non_existent() {
        let mut df = DataFrame::new();
        df.add_column("col_int", vec![1, 2, 3]);
        df.drop_column("non_existent");
    }
    // Adding more columns of an already-present type must not create additional frames.
    #[test]
    fn test_dataframe_add_column_reuses_existing_frame() {
        let mut df = DataFrame::new();
        df.add_column("col_int1", vec![1, 2, 3]);
        df.add_column("col_float1", vec![1.1, 2.2, 3.3]);
        // Initially, there should be two frames (one for i32, one for f64)
        assert_eq!(df.frames_by_type.len(), 2);
        assert!(df.frames_by_type.contains_key(&TypeId::of::<i32>()));
        assert!(df.frames_by_type.contains_key(&TypeId::of::<f64>()));
        // Add another integer column
        df.add_column("col_int2", vec![4, 5, 6]);
        // The number of frames should still be 2, as the existing i32 frame should be reused
        assert_eq!(df.frames_by_type.len(), 2);
        assert!(df.frames_by_type.contains_key(&TypeId::of::<i32>()));
        assert!(df.frames_by_type.contains_key(&TypeId::of::<f64>()));
        // Verify the i32 frame now contains both integer columns
        let int_frame = df.frames_by_type.get(&TypeId::of::<i32>()).unwrap().as_any().downcast_ref::<Frame<i32>>().unwrap();
        assert_eq!(int_frame.columns(), &vec!["col_int1".to_string(), "col_int2".to_string()]);
        assert_eq!(int_frame.cols(), 2);
        // Add another float column
        df.add_column("col_float2", vec![4.4, 5.5, 6.6]);
        // The number of frames should still be 2, as the existing f64 frame should be reused
        assert_eq!(df.frames_by_type.len(), 2);
        assert!(df.frames_by_type.contains_key(&TypeId::of::<i32>()));
        assert!(df.frames_by_type.contains_key(&TypeId::of::<f64>()));
        // Verify the f64 frame now contains both float columns
        let float_frame = df.frames_by_type.get(&TypeId::of::<f64>()).unwrap().as_any().downcast_ref::<Frame<f64>>().unwrap();
        assert_eq!(float_frame.columns(), &vec!["col_float1".to_string(), "col_float2".to_string()]);
        assert_eq!(float_frame.cols(), 2);
    }
}

4
src/dataframe/mod.rs Normal file
View File

@ -0,0 +1,4 @@
//! This module provides the DataFrame structure for handling tabular data with mixed types.
pub mod df;
pub use df::{DataFrame, SubFrame};

View File

@ -316,7 +316,7 @@ impl<T: Clone + PartialEq> Frame<T> {
) )
} }
/// Returns an immutable slice of the specified column's data. /// Returns an immutable slice of the specified column's data by name.
/// Panics if the column name is not found. /// Panics if the column name is not found.
pub fn column(&self, name: &str) -> &[T] { pub fn column(&self, name: &str) -> &[T] {
let idx = self let idx = self
@ -325,7 +325,13 @@ impl<T: Clone + PartialEq> Frame<T> {
self.matrix.column(idx) self.matrix.column(idx)
} }
/// Returns a mutable slice of the specified column's data. /// Returns an immutable slice of the specified column's data by its physical index.
/// Panics if the index is out of bounds.
///
/// `idx` is passed unchanged to the underlying matrix's `column`; no lookup by
/// column name takes place.
pub fn column_by_physical_idx(&self, idx: usize) -> &[T] {
self.matrix.column(idx)
}
/// Returns a mutable slice of the specified column's data by name.
/// Panics if the column name is not found. /// Panics if the column name is not found.
pub fn column_mut(&mut self, name: &str) -> &mut [T] { pub fn column_mut(&mut self, name: &str) -> &mut [T] {
let idx = self let idx = self
@ -334,6 +340,12 @@ impl<T: Clone + PartialEq> Frame<T> {
self.matrix.column_mut(idx) self.matrix.column_mut(idx)
} }
/// Returns a mutable slice of the specified column's data by its physical index.
///
/// `idx` is passed directly to the underlying matrix; no column-name lookup
/// is performed (use `column_mut` to address a column by name).
///
/// # Panics
///
/// Panics if the index is out of bounds.
pub fn column_mut_by_physical_idx(&mut self, idx: usize) -> &mut [T] {
self.matrix.column_mut(idx)
}
// Row access methods // Row access methods
/// Returns an immutable view of the row for the given integer key. /// Returns an immutable view of the row for the given integer key.
@ -471,6 +483,12 @@ impl<T: Clone + PartialEq> Frame<T> {
deleted_data deleted_data
} }
/// Returns a new `Matrix` that is the transpose of the current frame's matrix.
///
/// The frame itself is left unchanged (`&self`). Note that the result is a
/// bare `Matrix<T>`, not a `Frame<T>`: only the matrix data is transposed and
/// returned by this method.
pub fn transpose(&self) -> Matrix<T> {
self.matrix.transpose()
}
/// Sorts columns alphabetically by name, preserving data associations. /// Sorts columns alphabetically by name, preserving data associations.
pub fn sort_columns(&mut self) { pub fn sort_columns(&mut self) {
let n = self.column_names.len(); let n = self.column_names.len();

View File

@ -20,6 +20,13 @@ impl SeriesOps for Frame<f64> {
{ {
self.matrix().apply_axis(axis, f) self.matrix().apply_axis(axis, f)
} }
/// Forwards to `matrix_mul` on the two frames' underlying matrices.
///
/// The result is a plain `FloatMatrix`, not a `Frame`; only `self.matrix()`
/// and `other.matrix()` take part in the computation. Any shape validation
/// is whatever the matrix-level `matrix_mul` performs.
fn matrix_mul(&self, other: &Self) -> FloatMatrix {
self.matrix().matrix_mul(other.matrix())
}
/// Forwards to `dot` on the two frames' underlying matrices.
///
/// Returns the matrix-level result as a plain `FloatMatrix`; only
/// `self.matrix()` and `other.matrix()` take part in the computation.
fn dot(&self, other: &Self) -> FloatMatrix {
self.matrix().dot(other.matrix())
}
delegate_to_matrix!( delegate_to_matrix!(
sum_vertical -> Vec<f64>, sum_vertical -> Vec<f64>,
@ -128,6 +135,17 @@ mod tests {
); );
assert_eq!(frame.is_nan(), frame.matrix().is_nan()); assert_eq!(frame.is_nan(), frame.matrix().is_nan());
assert_eq!(frame.apply_axis(Axis::Row, |x| x[0] + x[1]), vec![4.0, 6.0]); assert_eq!(frame.apply_axis(Axis::Row, |x| x[0] + x[1]), vec![4.0, 6.0]);
assert_eq!(frame.matrix_mul(&frame), frame.matrix().matrix_mul(&frame.matrix()));
assert_eq!(frame.dot(&frame), frame.matrix().dot(&frame.matrix()));
// test transpose - returns a matrix.
let frame_transposed_mat = frame.transpose();
let frame_mat_transposed = frame.matrix().transpose();
assert_eq!(frame_transposed_mat, frame_mat_transposed);
assert_eq!(frame.matrix(), &frame.matrix().transpose().transpose());
} }
#[test] #[test]
@ -152,4 +170,7 @@ mod tests {
vec![false, false] vec![false, false]
); );
} }
} }

View File

@ -1,5 +1,8 @@
#![doc = include_str!("../README.md")] #![doc = include_str!("../README.md")]
/// Documentation for the [`crate::dataframe`] module.
pub mod dataframe;
/// Documentation for the [`crate::matrix`] module. /// Documentation for the [`crate::matrix`] module.
pub mod matrix; pub mod matrix;

View File

@ -168,7 +168,12 @@ impl<T: Clone> Matrix<T> {
/// Deletes a column from the matrix. Panics on out-of-bounds. /// Deletes a column from the matrix. Panics on out-of-bounds.
/// This is O(N) where N is the number of elements. /// This is O(N) where N is the number of elements.
pub fn delete_column(&mut self, col: usize) { pub fn delete_column(&mut self, col: usize) {
assert!(col < self.cols, "column index {} out of bounds for {} columns", col, self.cols); assert!(
col < self.cols,
"column index {} out of bounds for {} columns",
col,
self.cols
);
let start = col * self.rows; let start = col * self.rows;
self.data.drain(start..start + self.rows); // Efficient removal self.data.drain(start..start + self.rows); // Efficient removal
self.cols -= 1; self.cols -= 1;
@ -203,6 +208,22 @@ impl<T: Clone> Matrix<T> {
self.data = new_data; self.data = new_data;
self.rows = new_rows; self.rows = new_rows;
} }
/// Builds the transpose of this matrix, leaving `self` untouched.
///
/// For a source of shape `rows × cols` the result has shape `cols × rows`,
/// and its element `(i, j)` is a clone of `self[(j, i)]`.
pub fn transpose(&self) -> Matrix<T> {
    let src_rows = self.rows;
    let src_cols = self.cols;
    let mut buf = Vec::with_capacity(src_rows * src_cols);

    // Walk the source one row at a time, left to right. Each source row is
    // pushed as one contiguous run of `src_cols` elements, i.e. one column
    // of the result.
    buf.extend(
        (0..src_rows).flat_map(|r| (0..src_cols).map(move |c| self[(r, c)].clone())),
    );

    // Dimensions are swapped relative to the source.
    Matrix::from_vec(buf, src_cols, src_rows)
}
} }
impl<T: Clone> Matrix<T> { impl<T: Clone> Matrix<T> {
@ -238,8 +259,19 @@ impl<T: Clone> Matrix<T> {
/// Adds a row to the matrix at the specified index. Panics if index > rows or length mismatch. /// Adds a row to the matrix at the specified index. Panics if index > rows or length mismatch.
/// This is O(N) where N is the number of elements, as it rebuilds the data vec. /// This is O(N) where N is the number of elements, as it rebuilds the data vec.
pub fn add_row(&mut self, index: usize, row: Vec<T>) { pub fn add_row(&mut self, index: usize, row: Vec<T>) {
assert!(index <= self.rows, "add_row index {} out of bounds for {} rows", index, self.rows); assert!(
assert_eq!(row.len(), self.cols, "row length mismatch: expected {} (cols), got {}", self.cols, row.len()); index <= self.rows,
"add_row index {} out of bounds for {} rows",
index,
self.rows
);
assert_eq!(
row.len(),
self.cols,
"row length mismatch: expected {} (cols), got {}",
self.cols,
row.len()
);
if self.cols == 0 && self.rows == 0 { if self.cols == 0 && self.rows == 0 {
// Special case: adding first row to empty matrix // Special case: adding first row to empty matrix
@ -624,7 +656,6 @@ pub type BoolMatrix = Matrix<bool>;
pub type IntMatrix = Matrix<i32>; pub type IntMatrix = Matrix<i32>;
pub type StringMatrix = Matrix<String>; pub type StringMatrix = Matrix<String>;
// --- Unit Tests --- // --- Unit Tests ---
#[cfg(test)] #[cfg(test)]
@ -903,6 +934,40 @@ mod tests {
assert_eq!(matrix[(1, 2)], 6); // Second row, third col assert_eq!(matrix[(1, 2)], 6); // Second row, third col
} }
/// Transposing the static test matrix must mirror every element across the
/// diagonal, and transposing a second time must give back the original.
#[test]
fn test_transpose() {
    let original = static_test_matrix();
    let flipped = original.transpose();

    // Round trip: (Aᵀ)ᵀ == A.
    assert_eq!(
        flipped.transpose(),
        original,
        "Transposing twice should return original matrix"
    );

    // Element-wise: entry (row, col) of the original sits at (col, row).
    for row in 0..original.rows() {
        for col in 0..original.cols() {
            assert_eq!(original[(row, col)], flipped[(col, row)]);
        }
    }
}
/// Transposes a non-square 100 × 200 matrix and checks shape, element
/// placement, and the round trip back to the original.
#[test]
fn test_transpose_big() {
    let data: Vec<i32> = (1..=20000).collect();
    let matrix = Matrix::from_vec(data, 100, 200);
    let transposed = matrix.transpose();

    assert_eq!(transposed.rows(), 200);
    assert_eq!(transposed.cols(), 100);
    assert_eq!(transposed.data().len(), 20000);

    // (0, 0) stays put under transposition, so on its own it proves nothing
    // about element placement.
    assert_eq!(transposed[(0, 0)], 1);

    // Off-diagonal entries must actually move: (r, c) -> (c, r). Without these
    // checks an implementation that merely swapped the dimensions and kept the
    // data in place would still pass the shape and round-trip assertions.
    assert_eq!(transposed[(1, 0)], matrix[(0, 1)]);
    assert_eq!(transposed[(0, 1)], matrix[(1, 0)]);
    assert_eq!(transposed[(199, 99)], matrix[(99, 199)]);

    let round_trip = transposed.transpose();
    assert_eq!(
        round_trip, matrix,
        "Transposing back should return original matrix"
    );
}
#[test] #[test]
#[should_panic(expected = "data length mismatch")] #[should_panic(expected = "data length mismatch")]
fn test_from_vec_wrong_length() { fn test_from_vec_wrong_length() {
@ -999,20 +1064,20 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "index out of bounds")] #[should_panic(expected = "index out of bounds")]
fn test_index_out_of_bounds_row_alt() { fn test_index_out_of_bounds_row_alt() {
let matrix = static_test_matrix(); // 3x3 let matrix = static_test_matrix();
let _ = matrix[(3, 0)]; let _ = matrix[(3, 0)];
} }
#[test] #[test]
#[should_panic(expected = "index out of bounds")] #[should_panic(expected = "index out of bounds")]
fn test_index_out_of_bounds_col_alt() { fn test_index_out_of_bounds_col_alt() {
let matrix = static_test_matrix(); // 3x3 let matrix = static_test_matrix();
let _ = matrix[(0, 3)]; let _ = matrix[(0, 3)];
} }
#[test] #[test]
fn test_index_mut_and_get_mut() { fn test_index_mut_and_get_mut() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
matrix[(0, 0)] = 10; matrix[(0, 0)] = 10;
matrix[(1, 1)] = 20; matrix[(1, 1)] = 20;
@ -1028,33 +1093,26 @@ mod tests {
assert_eq!(matrix[(0, 1)], 15); assert_eq!(matrix[(0, 1)], 15);
assert_eq!(matrix[(2, 1)], 25); assert_eq!(matrix[(2, 1)], 25);
// Check underlying data consistency (column-major)
// Should be:
// 10 15 7
// 2 20 8
// 3 25 30
assert_eq!(matrix.data(), &[10, 2, 3, 15, 20, 25, 7, 8, 30]); assert_eq!(matrix.data(), &[10, 2, 3, 15, 20, 25, 7, 8, 30]);
} }
#[test] #[test]
#[should_panic(expected = "index out of bounds")] #[should_panic(expected = "index out of bounds")]
fn test_index_mut_out_of_bounds_row() { fn test_index_mut_out_of_bounds_row() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
matrix[(3, 0)] = 99; matrix[(3, 0)] = 99;
} }
#[test] #[test]
#[should_panic(expected = "index out of bounds")] #[should_panic(expected = "index out of bounds")]
fn test_index_mut_out_of_bounds_col() { fn test_index_mut_out_of_bounds_col() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
matrix[(0, 3)] = 99; matrix[(0, 3)] = 99;
} }
#[test] #[test]
fn test_column() { fn test_column() {
let matrix = static_test_matrix_2x4(); // 2x4 let matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
assert_eq!(matrix.column(0), &[1, 2]); assert_eq!(matrix.column(0), &[1, 2]);
assert_eq!(matrix.column(1), &[3, 4]); assert_eq!(matrix.column(1), &[3, 4]);
@ -1065,15 +1123,13 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")] #[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_column_out_of_bounds() { fn test_column_out_of_bounds() {
let matrix = static_test_matrix_2x4(); // 2x4 let matrix = static_test_matrix_2x4();
matrix.column(4); matrix.column(4);
} }
#[test] #[test]
fn test_column_mut() { fn test_column_mut() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
let col1_mut = matrix.column_mut(1); let col1_mut = matrix.column_mut(1);
col1_mut[0] = 30; col1_mut[0] = 30;
@ -1088,25 +1144,19 @@ mod tests {
assert_eq!(matrix[(0, 3)], 70); assert_eq!(matrix[(0, 3)], 70);
assert_eq!(matrix[(1, 3)], 8); // Unchanged assert_eq!(matrix[(1, 3)], 8); // Unchanged
// Check underlying data (column-major)
// Should be:
// 1 30 5 70
// 2 40 6 8
assert_eq!(matrix.data(), &[1, 2, 30, 40, 5, 6, 70, 8]); assert_eq!(matrix.data(), &[1, 2, 30, 40, 5, 6, 70, 8]);
} }
#[test] #[test]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")] #[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_column_mut_out_of_bounds() { fn test_column_mut_out_of_bounds() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
matrix.column_mut(4); matrix.column_mut(4);
} }
#[test] #[test]
fn test_iter_columns() { fn test_iter_columns() {
let matrix = static_test_matrix_2x4(); // 2x4 let matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
let cols: Vec<&[i32]> = matrix.iter_columns().collect(); let cols: Vec<&[i32]> = matrix.iter_columns().collect();
assert_eq!(cols.len(), 4); assert_eq!(cols.len(), 4);
@ -1118,9 +1168,7 @@ mod tests {
#[test] #[test]
fn test_iter_rows() { fn test_iter_rows() {
let matrix = static_test_matrix_2x4(); // 2x4 let matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
let rows: Vec<Vec<i32>> = matrix let rows: Vec<Vec<i32>> = matrix
.iter_rows() .iter_rows()
@ -1134,10 +1182,7 @@ mod tests {
// test data_mut // test data_mut
#[test] #[test]
fn test_data_mut() { fn test_data_mut() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
// 1 4 7
// 2 5 8
// 3 6 9
let data_mut = matrix.data_mut(); let data_mut = matrix.data_mut();
data_mut[0] = 10; data_mut[0] = 10;
@ -1149,10 +1194,7 @@ mod tests {
#[test] #[test]
fn test_matrix_row_get_and_iter() { fn test_matrix_row_get_and_iter() {
let matrix = static_test_matrix_2x4(); // 2x4 let matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
let row0 = matrix.iter_rows().next().unwrap(); let row0 = matrix.iter_rows().next().unwrap();
assert_eq!(*row0.get(0), 1); assert_eq!(*row0.get(0), 1);
assert_eq!(*row0.get(1), 3); assert_eq!(*row0.get(1), 3);
@ -1169,17 +1211,9 @@ mod tests {
#[test] #[test]
fn test_swap_columns() { fn test_swap_columns() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
// 1 4 7
// 2 5 8
// 3 6 9
matrix.swap_columns(0, 2); // Swap first and last matrix.swap_columns(0, 2); // swap first and last
// Should be:
// 7 4 1
// 8 5 2
// 9 6 3
assert_eq!(matrix.rows(), 3); assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3); assert_eq!(matrix.cols(), 3);
@ -1193,7 +1227,7 @@ mod tests {
assert_eq!(matrix[(1, 2)], 2); assert_eq!(matrix[(1, 2)], 2);
assert_eq!(matrix[(2, 2)], 3); assert_eq!(matrix[(2, 2)], 3);
// Swap the same column (should do nothing) // swap the same column (should do nothing)
let original_data = matrix.data().to_vec(); let original_data = matrix.data().to_vec();
matrix.swap_columns(1, 1); matrix.swap_columns(1, 1);
assert_eq!(matrix.data(), &original_data); // Data should be identical assert_eq!(matrix.data(), &original_data); // Data should be identical
@ -1205,22 +1239,15 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "column index c2=3 out of bounds for 3 columns")] #[should_panic(expected = "column index c2=3 out of bounds for 3 columns")]
fn test_swap_columns_out_of_bounds() { fn test_swap_columns_out_of_bounds() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
matrix.swap_columns(0, 3); matrix.swap_columns(0, 3);
} }
#[test] #[test]
fn test_delete_column() { fn test_delete_column() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
matrix.delete_column(1); // Delete the second column (index 1) matrix.delete_column(1); // Delete the second column (index 1)
// Should be:
// 1 5 7
// 2 6 8
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3); assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], 1); assert_eq!(matrix[(0, 0)], 1);
@ -1230,23 +1257,17 @@ mod tests {
assert_eq!(matrix[(0, 2)], 7); assert_eq!(matrix[(0, 2)], 7);
assert_eq!(matrix[(1, 2)], 8); assert_eq!(matrix[(1, 2)], 8);
// Check underlying data (column-major) // check underlying data
assert_eq!(matrix.data(), &[1, 2, 5, 6, 7, 8]); assert_eq!(matrix.data(), &[1, 2, 5, 6, 7, 8]);
// Delete the first column // Delete the first column
matrix.delete_column(0); matrix.delete_column(0);
// Should be:
// 5 7
// 6 8
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 2); assert_eq!(matrix.cols(), 2);
assert_eq!(matrix.data(), &[5, 6, 7, 8]); assert_eq!(matrix.data(), &[5, 6, 7, 8]);
// Delete the last column // Delete the last column
matrix.delete_column(1); matrix.delete_column(1);
// Should be:
// 5
// 6
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 1); assert_eq!(matrix.cols(), 1);
assert_eq!(matrix.data(), &[5, 6]); assert_eq!(matrix.data(), &[5, 6]);
@ -1262,22 +1283,15 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")] #[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_delete_column_out_of_bounds() { fn test_delete_column_out_of_bounds() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
matrix.delete_column(4); matrix.delete_column(4);
} }
#[test] #[test]
fn test_delete_row() { fn test_delete_row() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
// 1 4 7
// 2 5 8
// 3 6 9
matrix.delete_row(1); // Delete the second row (index 1) matrix.delete_row(1); // Delete the second row
// Should be:
// 1 4 7
// 3 6 9
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3); assert_eq!(matrix.cols(), 3);
@ -1288,9 +1302,7 @@ mod tests {
assert_eq!(matrix[(0, 2)], 7); assert_eq!(matrix[(0, 2)], 7);
assert_eq!(matrix[(1, 2)], 9); assert_eq!(matrix[(1, 2)], 9);
// Check underlying data (column-major) // check underlying data (column-major)
// Original: [1, 2, 3, 4, 5, 6, 7, 8, 9]
// Delete row 1: [1, 3, 4, 6, 7, 9]
assert_eq!(matrix.data(), &[1, 3, 4, 6, 7, 9]); assert_eq!(matrix.data(), &[1, 3, 4, 6, 7, 9]);
// Delete the first row // Delete the first row
@ -1312,22 +1324,15 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "row index 3 out of bounds for 3 rows")] #[should_panic(expected = "row index 3 out of bounds for 3 rows")]
fn test_delete_row_out_of_bounds() { fn test_delete_row_out_of_bounds() {
let mut matrix = static_test_matrix(); // 3x3 let mut matrix = static_test_matrix();
matrix.delete_row(3); matrix.delete_row(3);
} }
#[test] #[test]
fn test_add_column() { fn test_add_column() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
let new_col = vec![9, 10]; let new_col = vec![9, 10];
matrix.add_column(2, new_col); // Add at index 2 matrix.add_column(2, new_col);
// Should be:
// 1 3 9 5 7
// 2 4 10 6 8
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 5); assert_eq!(matrix.cols(), 5);
@ -1342,17 +1347,12 @@ mod tests {
assert_eq!(matrix[(0, 4)], 7); // Shifted assert_eq!(matrix[(0, 4)], 7); // Shifted
assert_eq!(matrix[(1, 4)], 8); assert_eq!(matrix[(1, 4)], 8);
// Check underlying data (column-major) // Check underlying data
// Original: [1, 2, 3, 4, 5, 6, 7, 8]
// Add [9, 10] at index 2: [1, 2, 3, 4, 9, 10, 5, 6, 7, 8]
assert_eq!(matrix.data(), &[1, 2, 3, 4, 9, 10, 5, 6, 7, 8]); assert_eq!(matrix.data(), &[1, 2, 3, 4, 9, 10, 5, 6, 7, 8]);
// Add a column at the beginning // Add a column at the beginning
let new_col_start = vec![11, 12]; let new_col_start = vec![11, 12];
matrix.add_column(0, new_col_start); matrix.add_column(0, new_col_start);
// Should be:
// 11 1 3 9 5 7
// 12 2 4 10 6 8
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 6); assert_eq!(matrix.cols(), 6);
assert_eq!(matrix[(0, 0)], 11); assert_eq!(matrix[(0, 0)], 11);
@ -1362,9 +1362,6 @@ mod tests {
// Add a column at the end // Add a column at the end
let new_col_end = vec![13, 14]; let new_col_end = vec![13, 14];
matrix.add_column(6, new_col_end); matrix.add_column(6, new_col_end);
// Should be:
// 11 1 3 9 5 7 13
// 12 2 4 10 6 8 14
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 7); assert_eq!(matrix.cols(), 7);
assert_eq!(matrix[(0, 6)], 13); assert_eq!(matrix[(0, 6)], 13);
@ -1378,7 +1375,7 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "add_column index 5 out of bounds for 4 columns")] #[should_panic(expected = "add_column index 5 out of bounds for 4 columns")]
fn test_add_column_out_of_bounds() { fn test_add_column_out_of_bounds() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
let new_col = vec![9, 10]; let new_col = vec![9, 10];
matrix.add_column(5, new_col); // Index 5 is out of bounds for 4 columns matrix.add_column(5, new_col); // Index 5 is out of bounds for 4 columns
} }
@ -1386,24 +1383,16 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "column length mismatch")] #[should_panic(expected = "column length mismatch")]
fn test_add_column_length_mismatch() { fn test_add_column_length_mismatch() {
let mut matrix = static_test_matrix_2x4(); // 2x4 (2 rows) let mut matrix = static_test_matrix_2x4();
let new_col = vec![9, 10, 11]; // Wrong length let new_col = vec![9, 10, 11]; // Wrong length
matrix.add_column(0, new_col); matrix.add_column(0, new_col);
} }
#[test] #[test]
fn test_add_row() { fn test_add_row() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
// 1 3 5 7
// 2 4 6 8
let new_row = vec![9, 10, 11, 12]; let new_row = vec![9, 10, 11, 12];
matrix.add_row(1, new_row); // Add at index 1 matrix.add_row(1, new_row);
// Should be:
// 1 3 5 7
// 9 10 11 12
// 2 4 6 8
assert_eq!(matrix.rows(), 3); assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 4); assert_eq!(matrix.cols(), 4);
@ -1422,23 +1411,11 @@ mod tests {
assert_eq!(matrix[(2, 3)], 8); assert_eq!(matrix[(2, 3)], 8);
// Check underlying data (column-major) // Check underlying data (column-major)
// Original: [1, 2, 3, 4, 5, 6, 7, 8] (rows 0, 1)
// Add [9, 10, 11, 12] at index 1 (new row will be index 1, original row 1 becomes index 2)
// Col 0: [1, 9, 2]
// Col 1: [3, 10, 4]
// Col 2: [5, 11, 6]
// Col 3: [7, 12, 8]
// Data: [1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]
assert_eq!(matrix.data(), &[1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]); assert_eq!(matrix.data(), &[1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]);
// Add a row at the beginning // Add a row at the beginning
let new_row_start = vec![13, 14, 15, 16]; let new_row_start = vec![13, 14, 15, 16];
matrix.add_row(0, new_row_start); matrix.add_row(0, new_row_start);
// Should be:
// 13 14 15 16
// 1 3 5 7
// 9 10 11 12
// 2 4 6 8
assert_eq!(matrix.rows(), 4); assert_eq!(matrix.rows(), 4);
assert_eq!(matrix.cols(), 4); assert_eq!(matrix.cols(), 4);
assert_eq!(matrix[(0, 0)], 13); assert_eq!(matrix[(0, 0)], 13);
@ -1452,12 +1429,6 @@ mod tests {
// Add a row at the end // Add a row at the end
let new_row_end = vec![17, 18, 19, 20]; let new_row_end = vec![17, 18, 19, 20];
matrix.add_row(4, new_row_end); matrix.add_row(4, new_row_end);
// Should be:
// 13 14 15 16
// 1 3 5 7
// 9 10 11 12
// 2 4 6 8
// 17 18 19 20
assert_eq!(matrix.rows(), 5); assert_eq!(matrix.rows(), 5);
assert_eq!(matrix.cols(), 4); assert_eq!(matrix.cols(), 4);
assert_eq!(matrix[(4, 0)], 17); assert_eq!(matrix[(4, 0)], 17);
@ -1467,7 +1438,7 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "add_row index 3 out of bounds for 2 rows")] #[should_panic(expected = "add_row index 3 out of bounds for 2 rows")]
fn test_add_row_out_of_bounds() { fn test_add_row_out_of_bounds() {
let mut matrix = static_test_matrix_2x4(); // 2x4 let mut matrix = static_test_matrix_2x4();
let new_row = vec![9, 10, 11, 12]; let new_row = vec![9, 10, 11, 12];
matrix.add_row(3, new_row); // Index 3 is out of bounds for 2 rows matrix.add_row(3, new_row); // Index 3 is out of bounds for 2 rows
} }
@ -1475,26 +1446,22 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "row length mismatch")] #[should_panic(expected = "row length mismatch")]
fn test_add_row_length_mismatch() { fn test_add_row_length_mismatch() {
let mut matrix = static_test_matrix_2x4(); // 2x4 (4 cols) let mut matrix = static_test_matrix_2x4();
let new_row = vec![9, 10, 11]; // Wrong length let new_row = vec![9, 10, 11]; // Wrong length
matrix.add_row(0, new_row); matrix.add_row(0, new_row);
} }
#[test] #[test]
fn test_elementwise_add() { fn test_elementwise_add() {
let matrix1 = static_test_matrix(); // 3x3 let matrix1 = static_test_matrix();
let matrix2 = Matrix::from_vec(vec![9, 8, 7, 6, 5, 4, 3, 2, 1], 3, 3); // 3x3 let matrix2 = Matrix::from_vec(vec![9, 8, 7, 6, 5, 4, 3, 2, 1], 3, 3);
let result = &matrix1 + &matrix2; let result = &matrix1 + &matrix2;
assert_eq!(result.rows(), 3); assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3); assert_eq!(result.cols(), 3);
// Expected: // Expect all 10s
// 1+9 4+6 7+3 => 10 10 10
// 2+8 5+5 8+2 => 10 10 10
// 3+7 6+4 9+1 => 10 10 10
// Column-major data: [10, 10, 10, 10, 10, 10, 10, 10, 10]
assert_eq!(result.data(), &[10, 10, 10, 10, 10, 10, 10, 10, 10]); assert_eq!(result.data(), &[10, 10, 10, 10, 10, 10, 10, 10, 10]);
assert_eq!(result[(0, 0)], 10); assert_eq!(result[(0, 0)], 10);
assert_eq!(result[(1, 1)], 10); assert_eq!(result[(1, 1)], 10);
@ -1503,19 +1470,14 @@ mod tests {
#[test] #[test]
fn test_elementwise_sub() { fn test_elementwise_sub() {
let matrix1 = static_test_matrix(); // 3x3 let matrix1 = static_test_matrix();
let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], 3, 3); // 3x3 let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], 3, 3);
let result = &matrix1 - &matrix2; let result = &matrix1 - &matrix2;
assert_eq!(result.rows(), 3); assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3); assert_eq!(result.cols(), 3);
// Expected:
// 1-1 4-2 7-3 => 0 2 4
// 2-1 5-2 8-3 => 1 3 5
// 3-1 6-2 9-3 => 2 4 6
// Column-major data: [0, 1, 2, 2, 3, 4, 4, 5, 6]
assert_eq!(result.data(), &[0, 1, 2, 2, 3, 4, 4, 5, 6]); assert_eq!(result.data(), &[0, 1, 2, 2, 3, 4, 4, 5, 6]);
assert_eq!(result[(0, 0)], 0); assert_eq!(result[(0, 0)], 0);
assert_eq!(result[(1, 1)], 3); assert_eq!(result[(1, 1)], 3);
@ -1524,19 +1486,15 @@ mod tests {
#[test] #[test]
fn test_elementwise_mul() { fn test_elementwise_mul() {
let matrix1 = static_test_matrix(); // 3x3 let matrix1 = static_test_matrix();
let matrix2 = Matrix::from_vec(vec![1, 2, 3, 1, 2, 3, 1, 2, 3], 3, 3); // 3x3 let matrix2 = Matrix::from_vec(vec![1, 2, 3, 1, 2, 3, 1, 2, 3], 3, 3);
let result = &matrix1 * &matrix2; let result = &matrix1 * &matrix2;
assert_eq!(result.rows(), 3); assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3); assert_eq!(result.cols(), 3);
// Expected: // Expected
// 1*1 4*1 7*1 => 1 4 7
// 2*2 5*2 8*2 => 4 10 16
// 3*3 6*3 9*3 => 9 18 27
// Column-major data: [1, 4, 9, 4, 10, 18, 7, 16, 27]
assert_eq!(result.data(), &[1, 4, 9, 4, 10, 18, 7, 16, 27]); assert_eq!(result.data(), &[1, 4, 9, 4, 10, 18, 7, 16, 27]);
assert_eq!(result[(0, 0)], 1); assert_eq!(result[(0, 0)], 1);
assert_eq!(result[(1, 1)], 10); assert_eq!(result[(1, 1)], 10);
@ -1545,19 +1503,14 @@ mod tests {
#[test] #[test]
fn test_elementwise_div() { fn test_elementwise_div() {
let matrix1 = static_test_matrix(); // 3x3 let matrix1 = static_test_matrix();
let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 7, 8, 9], 3, 3); // 3x3 let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 7, 8, 9], 3, 3);
let result = &matrix1 / &matrix2; // Integer division let result = &matrix1 / &matrix2; // Integer division
assert_eq!(result.rows(), 3); assert_eq!(result.rows(), 3);
assert_eq!(result.cols(), 3); assert_eq!(result.cols(), 3);
// Expected:
// 1/1 4/2 7/7 => 1 2 1
// 2/1 5/2 8/8 => 2 2 1 (integer division)
// 3/1 6/2 9/9 => 3 3 1
// Column-major data: [1, 2, 3, 2, 2, 3, 1, 1, 1]
assert_eq!(result.data(), &[1, 2, 3, 2, 2, 3, 1, 1, 1]); assert_eq!(result.data(), &[1, 2, 3, 2, 2, 3, 1, 1, 1]);
assert_eq!(result[(0, 0)], 1); assert_eq!(result[(0, 0)], 1);
assert_eq!(result[(1, 1)], 2); assert_eq!(result[(1, 1)], 2);
@ -1567,34 +1520,26 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 rows")] #[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 rows")]
fn test_elementwise_op_row_mismatch() { fn test_elementwise_op_row_mismatch() {
let matrix1 = static_test_matrix(); // 3x3 let matrix1 = static_test_matrix();
let matrix2 = static_test_matrix_2x4(); // 2x4 let matrix2 = static_test_matrix_2x4();
let _ = &matrix1 + &matrix2; // Should panic let _ = &matrix1 + &matrix2; // Should panic
} }
#[test] #[test]
#[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 ro")] #[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 ro")]
fn test_elementwise_op_col_mismatch() { fn test_elementwise_op_col_mismatch() {
let matrix1 = static_test_matrix(); // 3x3 let matrix1 = static_test_matrix();
let matrix2 = static_test_matrix_2x4(); // 2x4 let matrix2 = static_test_matrix_2x4();
let _ = &matrix1 * &matrix2; // Should panic let _ = &matrix1 * &matrix2; // Should panic
} }
#[test] #[test]
fn test_bitwise_and() { fn test_bitwise_and() {
let data1 = vec![true, false, true, false, true, false]; // 2x3 let data1 = vec![true, false, true, false, true, false];
let data2 = vec![true, true, false, false, true, true]; // 2x3 let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 3); let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3); let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
// Expected column-major results:
// T & T = T
// F & T = F
// T & F = F
// F & F = F
// T & T = T
// F & T = F
// Data: [T, F, F, F, T, F]
let expected_data = vec![true, false, false, false, true, false]; let expected_data = vec![true, false, false, false, true, false];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
@ -1604,19 +1549,11 @@ mod tests {
#[test] #[test]
fn test_bitwise_or() { fn test_bitwise_or() {
let data1 = vec![true, false, true, false, true, false]; // 2x3 let data1 = vec![true, false, true, false, true, false];
let data2 = vec![true, true, false, false, true, true]; // 2x3 let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 3); let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3); let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
// Expected column-major results:
// T | T = T
// F | T = T
// T | F = T
// F | F = F
// T | T = T
// F | T = T
// Data: [T, T, T, F, T, T]
let expected_data = vec![true, true, true, false, true, true]; let expected_data = vec![true, true, true, false, true, true];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
@ -1626,19 +1563,11 @@ mod tests {
#[test] #[test]
fn test_bitwise_xor() { fn test_bitwise_xor() {
let data1 = vec![true, false, true, false, true, false]; // 2x3 let data1 = vec![true, false, true, false, true, false];
let data2 = vec![true, true, false, false, true, true]; // 2x3 let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 3); let matrix1 = BoolMatrix::from_vec(data1, 2, 3);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3); let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
// Expected column-major results:
// T ^ T = F
// F ^ T = T
// T ^ F = T
// F ^ F = F
// T ^ T = F
// F ^ T = T
// Data: [F, T, T, F, F, T]
let expected_data = vec![false, true, true, false, false, true]; let expected_data = vec![false, true, true, false, false, true];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
@ -1648,17 +1577,9 @@ mod tests {
#[test] #[test]
fn test_bitwise_not() { fn test_bitwise_not() {
let data = vec![true, false, true, false, true, false]; // 2x3 let data = vec![true, false, true, false, true, false];
let matrix = BoolMatrix::from_vec(data, 2, 3); let matrix = BoolMatrix::from_vec(data, 2, 3);
// Expected column-major results:
// !T = F
// !F = T
// !T = F
// !F = T
// !T = F
// !F = T
// Data: [F, T, F, T, F, T]
let expected_data = vec![false, true, false, true, false, true]; let expected_data = vec![false, true, false, true, false, true];
let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3);
@ -1669,8 +1590,8 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")] #[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")]
fn test_bitwise_op_row_mismatch() { fn test_bitwise_op_row_mismatch() {
let data1 = vec![true, false, true, false]; // 2x2 let data1 = vec![true, false, true, false];
let data2 = vec![true, true, false, false, true, true]; // 2x3 let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 2); let matrix1 = BoolMatrix::from_vec(data1, 2, 2);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3); let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
let _ = &matrix1 & &matrix2; // Should panic let _ = &matrix1 & &matrix2; // Should panic
@ -1679,8 +1600,8 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")] #[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")]
fn test_bitwise_op_col_mismatch() { fn test_bitwise_op_col_mismatch() {
let data1 = vec![true, false, true, false]; // 2x2 let data1 = vec![true, false, true, false];
let data2 = vec![true, true, false, false, true, true]; // 2x3 let data2 = vec![true, true, false, false, true, true];
let matrix1 = BoolMatrix::from_vec(data1, 2, 2); let matrix1 = BoolMatrix::from_vec(data1, 2, 2);
let matrix2 = BoolMatrix::from_vec(data2, 2, 3); let matrix2 = BoolMatrix::from_vec(data2, 2, 3);
let _ = &matrix1 | &matrix2; // Should panic let _ = &matrix1 | &matrix2; // Should panic
@ -1695,7 +1616,7 @@ mod tests {
"c".to_string(), "c".to_string(),
"d".to_string(), "d".to_string(),
]; ];
let matrix = StringMatrix::from_vec(data.clone(), 2, 2); // 2x2 let matrix = StringMatrix::from_vec(data.clone(), 2, 2);
assert_eq!(matrix[(0, 0)], "a".to_string()); assert_eq!(matrix[(0, 0)], "a".to_string());
assert_eq!(matrix[(1, 0)], "b".to_string()); assert_eq!(matrix[(1, 0)], "b".to_string());
@ -1707,12 +1628,10 @@ mod tests {
matrix[(0, 0)] = "hello".to_string(); matrix[(0, 0)] = "hello".to_string();
assert_eq!(matrix[(0, 0)], "hello".to_string()); assert_eq!(matrix[(0, 0)], "hello".to_string());
// Test add_column (requires Clone) // Test add_column
let new_col = vec!["e".to_string(), "f".to_string()]; let new_col = vec!["e".to_string(), "f".to_string()];
matrix.add_column(1, new_col); // Add at index 1 matrix.add_column(1, new_col);
// Should be:
// hello c d
// b e f
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3); assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], "hello".to_string()); assert_eq!(matrix[(0, 0)], "hello".to_string());
@ -1724,11 +1643,8 @@ mod tests {
// Test add_row (requires Clone) // Test add_row (requires Clone)
let new_row = vec!["g".to_string(), "h".to_string(), "i".to_string()]; let new_row = vec!["g".to_string(), "h".to_string(), "i".to_string()];
matrix.add_row(0, new_row); // Add at index 0 matrix.add_row(0, new_row);
// Should be:
// g h i
// hello e c
// b f d
assert_eq!(matrix.rows(), 3); assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3); assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], "g".to_string()); assert_eq!(matrix[(0, 0)], "g".to_string());
@ -1740,8 +1656,8 @@ mod tests {
#[test] #[test]
fn test_float_matrix_ops() { fn test_float_matrix_ops() {
let data1 = vec![1.0, 2.0, 3.0, 4.0]; // 2x2 let data1 = vec![1.0, 2.0, 3.0, 4.0];
let data2 = vec![0.5, 1.5, 2.5, 3.5]; // 2x2 let data2 = vec![0.5, 1.5, 2.5, 3.5];
let matrix1 = FloatMatrix::from_vec(data1, 2, 2); let matrix1 = FloatMatrix::from_vec(data1, 2, 2);
let matrix2 = FloatMatrix::from_vec(data2, 2, 2); let matrix2 = FloatMatrix::from_vec(data2, 2, 2);
@ -1750,17 +1666,13 @@ mod tests {
let prod = &matrix1 * &matrix2; let prod = &matrix1 * &matrix2;
let div = &matrix1 / &matrix2; let div = &matrix1 / &matrix2;
// Check sums (col-major): [1.5, 3.5, 5.5, 7.5]
assert_eq!(sum.data(), &[1.5, 3.5, 5.5, 7.5]); assert_eq!(sum.data(), &[1.5, 3.5, 5.5, 7.5]);
// Check diffs (col-major): [0.5, 0.5, 0.5, 0.5]
assert_eq!(diff.data(), &[0.5, 0.5, 0.5, 0.5]); assert_eq!(diff.data(), &[0.5, 0.5, 0.5, 0.5]);
// Check prods (col-major): [0.5, 3.0, 7.5, 14.0]
assert_eq!(prod.data(), &[0.5, 3.0, 7.5, 14.0]); assert_eq!(prod.data(), &[0.5, 3.0, 7.5, 14.0]);
// Check divs (col-major): [2.0, 1.333..., 1.2, 1.14...] // Check divs (col-major): [2.0, 1.333..., 1.2, 1.14...]
// Using element access for more specific checks on floating point results
assert_eq!(div.rows(), 2); assert_eq!(div.rows(), 2);
assert_eq!(div.cols(), 2); assert_eq!(div.cols(), 2);
assert!((div[(0, 0)] - 1.0 / 0.5).abs() < 1e-9); // 2.0 assert!((div[(0, 0)] - 1.0 / 0.5).abs() < 1e-9); // 2.0
@ -1785,20 +1697,18 @@ mod tests {
// Perform the swap directly on the matrix // Perform the swap directly on the matrix
matrix.swap_columns(0, 2); // Swap column 0 and column 2 matrix.swap_columns(0, 2); // Swap column 0 and column 2
// --- Assertions --- // Verify the dimensions are unchanged
// 1. Verify the dimensions are unchanged
assert_eq!(matrix.rows(), 3, "Matrix rows should remain unchanged"); assert_eq!(matrix.rows(), 3, "Matrix rows should remain unchanged");
assert_eq!(matrix.cols(), 3, "Matrix cols should remain unchanged"); assert_eq!(matrix.cols(), 3, "Matrix cols should remain unchanged");
// 2. Verify the column that was NOT swapped is unchanged // Verify the column that was NOT swapped is unchanged
assert_eq!( assert_eq!(
matrix.column(1), matrix.column(1),
initial_col1_data.as_slice(), // Comparing slice to slice initial_col1_data.as_slice(), // Comparing slice to slice
"Column 1 data should be unchanged" "Column 1 data should be unchanged"
); );
// 3. Verify the data swap occurred correctly using the COLUMN ACCESSOR // Verify the data swap occurred correctly using the COLUMN ACCESSOR
// The data originally at index 0 should now be at index 2 // The data originally at index 0 should now be at index 2
assert_eq!( assert_eq!(
matrix.column(2), matrix.column(2),
@ -1812,16 +1722,14 @@ mod tests {
"Column 0 should now contain the original data from column 2" "Column 0 should now contain the original data from column 2"
); );
// 4. (Optional but useful) Verify the underlying raw data vector // Verify the underlying raw data vector
// Original data: [1, 2, 3, 4, 5, 6, 7, 8, 9]
// Expected data after swapping col 0 and col 2: [7, 8, 9, 4, 5, 6, 1, 2, 3]
assert_eq!( assert_eq!(
matrix.data(), matrix.data(),
&[7, 8, 9, 4, 5, 6, 1, 2, 3], &[7, 8, 9, 4, 5, 6, 1, 2, 3],
"Underlying data vector is incorrect after swap" "Underlying data vector is incorrect after swap"
); );
// 5. Test swapping with self (should be a no-op) // Test swapping with self (should be a no-op)
let state_before_self_swap = matrix.clone(); let state_before_self_swap = matrix.clone();
matrix.swap_columns(1, 1); matrix.swap_columns(1, 1);
assert_eq!( assert_eq!(
@ -1829,7 +1737,7 @@ mod tests {
"Swapping a column with itself should not change the matrix" "Swapping a column with itself should not change the matrix"
); );
// 6. Test swapping adjacent columns // Test swapping adjacent columns
let mut matrix2 = create_test_matrix_i32(); let mut matrix2 = create_test_matrix_i32();
let initial_col0_data_m2 = matrix2.column(0).to_vec(); let initial_col0_data_m2 = matrix2.column(0).to_vec();
let initial_col1_data_m2 = matrix2.column(1).to_vec(); let initial_col1_data_m2 = matrix2.column(1).to_vec();

View File

@ -12,6 +12,9 @@ pub trait SeriesOps {
where where
F: FnMut(&[f64]) -> U; F: FnMut(&[f64]) -> U;
/// Matrix product of `self` and `other`, returned as a `FloatMatrix`.
///
/// The `FloatMatrix` implementation panics unless `self.cols()` equals
/// `other.rows()`; other implementors presumably enforce the same rule.
fn matrix_mul(&self, other: &Self) -> FloatMatrix;
/// Dot product of `self` and `other`, returned as a `FloatMatrix`.
///
/// For `FloatMatrix` this gives the same result as `matrix_mul`.
fn dot(&self, other: &Self) -> FloatMatrix;
fn sum_vertical(&self) -> Vec<f64>; fn sum_vertical(&self) -> Vec<f64>;
fn sum_horizontal(&self) -> Vec<f64>; fn sum_horizontal(&self) -> Vec<f64>;
@ -139,11 +142,37 @@ impl SeriesOps for FloatMatrix {
let data = self.data().iter().map(|v| v.is_nan()).collect(); let data = self.data().iter().map(|v| v.is_nan()).collect();
BoolMatrix::from_vec(data, self.rows(), self.cols()) BoolMatrix::from_vec(data, self.rows(), self.cols())
} }
}
/// Computes the matrix product `self * other`.
///
/// The result has `self.rows()` rows and `other.cols()` columns.
///
/// # Panics
///
/// Panics when `self.cols()` differs from `other.rows()`.
fn matrix_mul(&self, other: &Self) -> FloatMatrix {
    let (lhs_rows, inner) = (self.rows(), self.cols());
    let (rhs_rows, out_cols) = (other.rows(), other.cols());
    assert_eq!(
        inner, rhs_rows,
        "Cannot multiply: left is {}x{}, right is {}x{}",
        lhs_rows, inner, rhs_rows, out_cols
    );

    // Storage is column-major, so emitting cells column by column (rows
    // varying fastest) places element (row, col) at `col * lhs_rows + row`.
    let mut cells = Vec::with_capacity(lhs_rows * out_cols);
    for col in 0..out_cols {
        for row in 0..lhs_rows {
            // Accumulate from 0.0 in increasing `k` so the floating-point
            // summation order is fixed.
            let cell = (0..inner).fold(0.0_f64, |acc, k| acc + self[(row, k)] * other[(k, col)]);
            cells.push(cell);
        }
    }
    FloatMatrix::from_vec(cells, lhs_rows, out_cols)
}
/// Dot product of two matrices.
///
/// Delegates directly to `matrix_mul`, so the result and the
/// dimension-mismatch panic are exactly those of `matrix_mul`.
fn dot(&self, other: &Self) -> FloatMatrix {
self.matrix_mul(other)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
// Helper function to create a FloatMatrix for SeriesOps testing // Helper function to create a FloatMatrix for SeriesOps testing
@ -156,6 +185,22 @@ mod tests {
FloatMatrix::from_vec(data, 3, 3) FloatMatrix::from_vec(data, 3, 3)
} }
// Helper function to create a 4x4 FloatMatrix with NaNs on the main diagonal.
//
// The backing vector is 0..16 in column-major order, with every index that is
// a multiple of 5 (0, 5, 10, 15) replaced by NaN. Read as a matrix:
//
//   NaN   4.0   8.0  12.0
//   1.0   NaN   9.0  13.0
//   2.0   6.0   NaN  14.0
//   3.0   7.0  11.0   NaN
fn create_float_test_matrix_4x4() -> FloatMatrix {
    FloatMatrix::from_vec(
        (0..16)
            .map(|i| if i % 5 == 0 { f64::NAN } else { i as f64 })
            .collect(),
        4,
        4,
    )
}
// --- Tests for SeriesOps (FloatMatrix) --- // --- Tests for SeriesOps (FloatMatrix) ---
#[test] #[test]
@ -256,6 +301,48 @@ mod tests {
assert_eq!(matrix.is_nan(), expected_matrix); assert_eq!(matrix.is_nan(), expected_matrix);
} }
#[test]
fn test_series_ops_matrix_mul() {
    // Both operands are 2x2 and given in column-major order:
    //   lhs = [[1, 3], [2, 4]]    rhs = [[5, 7], [6, 8]]
    let lhs = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
    let rhs = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2);

    // lhs * rhs = [[23, 31], [34, 46]], i.e. 23, 34, 31, 46 in column-major order.
    let product = lhs.matrix_mul(&rhs);
    assert_eq!(
        product,
        FloatMatrix::from_vec(vec![23.0, 34.0, 31.0, 46.0], 2, 2)
    );

    // For FloatMatrix, `dot` must agree with `matrix_mul`.
    assert_eq!(lhs.dot(&rhs), product);
}
#[test]
fn test_series_ops_matrix_mul_with_nans() {
    // Multiply the shared 3x3 NaN-containing fixture by itself.
    let lhs = create_float_test_matrix();
    let rhs = create_float_test_matrix();

    // Expected product: 30.0 in the first cell and NaN in the other eight.
    let mut expected_data = vec![f64::NAN; 9];
    expected_data[0] = 30.0;
    let expected = FloatMatrix::from_vec(expected_data, 3, 3);

    let product = lhs.matrix_mul(&rhs);

    // NaN never compares equal to itself, so compare the NaN masks and the
    // per-axis NaN counts instead of the matrices, then check the one
    // finite cell directly.
    assert_eq!(product.is_nan(), expected.is_nan());
    assert_eq!(
        product.count_nan_horizontal(),
        expected.count_nan_horizontal()
    );
    assert_eq!(product.count_nan_vertical(), expected.count_nan_vertical());
    assert_eq!(product[(0, 0)], expected[(0, 0)]);
}
#[test]
#[should_panic(expected = "Cannot multiply: left is 3x3, right is 4x4")]
fn test_series_ops_matrix_mul_errors() {
    let three_by_three = create_float_test_matrix();
    let four_by_four = create_float_test_matrix_4x4();
    // The left operand has 3 columns but the right has 4 rows, so the
    // dimension check reached through `dot` must panic.
    let _ = three_by_three.dot(&four_by_four);
}
// --- Edge Cases for SeriesOps --- // --- Edge Cases for SeriesOps ---
#[test] #[test]
@ -286,7 +373,7 @@ mod tests {
#[test] #[test]
fn test_series_ops_1xn_matrix() { fn test_series_ops_1xn_matrix() {
let matrix = FloatMatrix::from_vec(vec![1.0, f64::NAN, 3.0, 4.0], 1, 4); // 1 row, 4 cols let matrix = FloatMatrix::from_vec(vec![1.0, f64::NAN, 3.0, 4.0], 1, 4); // 1 row, 4 cols
// Data: [1.0, NaN, 3.0, 4.0] // Data: [1.0, NaN, 3.0, 4.0]
// Vertical (sums/prods/counts per column - each col is just one element) // Vertical (sums/prods/counts per column - each col is just one element)
assert_eq!(matrix.sum_vertical(), vec![1.0, 0.0, 3.0, 4.0]); // NaN sum is 0 assert_eq!(matrix.sum_vertical(), vec![1.0, 0.0, 3.0, 4.0]); // NaN sum is 0
@ -316,7 +403,7 @@ mod tests {
#[test] #[test]
fn test_series_ops_nx1_matrix() { fn test_series_ops_nx1_matrix() {
let matrix = FloatMatrix::from_vec(vec![1.0, 2.0, f64::NAN, 4.0], 4, 1); // 4 rows, 1 col let matrix = FloatMatrix::from_vec(vec![1.0, 2.0, f64::NAN, 4.0], 4, 1); // 4 rows, 1 col
// Data: [1.0, 2.0, NaN, 4.0] // Data: [1.0, 2.0, NaN, 4.0]
// Vertical (sums/prods/counts for the single column) // Vertical (sums/prods/counts for the single column)
// Col 0: 1.0 + 2.0 + NaN + 4.0 = 7.0 // Col 0: 1.0 + 2.0 + NaN + 4.0 = 7.0