diff --git a/.github/workflows/docs-and-testcov.yml b/.github/workflows/docs-and-testcov.yml index 68bd58c..632b7e6 100644 --- a/.github/workflows/docs-and-testcov.yml +++ b/.github/workflows/docs-and-testcov.yml @@ -105,10 +105,21 @@ jobs: > last-commit-date.json - name: Download last available benchmark report + env: + GH_TOKEN: ${{ secrets.CUSTOM_GH_TOKEN }} run: | - artifact_url=$(gh api -H "Accept: application/vnd.github+json" \ - /repos/${{ github.repository }}/actions/artifacts \ - | jq -r '.artifacts[] | select(.name | startswith("benchmark-reports")) | .archive_download_url' | head -n 1) + artifact_url=$( + curl -sSL \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${GH_TOKEN}" \ + "https://api.github.com/repos/${{ github.repository }}/actions/artifacts" \ + | jq -r ' + .artifacts[] + | select(.name | startswith("benchmark-reports")) + | .archive_download_url + ' \ + | head -n 1 + ) if [ -z "$artifact_url" ]; then echo "No benchmark artifact found!" @@ -117,7 +128,7 @@ jobs: exit 0 fi - curl -L -H "Authorization: Bearer ${{ secrets.CUSTOM_GH_TOKEN }}" \ + curl -L -H "Authorization: Bearer ${GH_TOKEN}" \ "$artifact_url" -o benchmark-report.zip # Print all files in the current directory diff --git a/Cargo.toml b/Cargo.toml index f8343d2..d3d44f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,9 @@ crate-type = ["cdylib", "lib"] chrono = "^0.4.10" criterion = { version = "0.5", features = ["html_reports"], optional = true } +[dev-dependencies] +rand = "^0.9.1" + [features] bench = ["dep:criterion"] diff --git a/README.md b/README.md index 708b636..496ec68 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,6 @@ -# Rustframe rustframe +# rustframe + + @@ -109,3 +111,15 @@ assert!(check); ### More examples See the [examples](./examples/) directory for some demonstrations of Rustframe's syntax and functionality. + +To run the examples, use: + +```bash +cargo run --example +``` + +E.g. to run the `game_of_life` example: + +```bash +cargo run --example game_of_life +``` diff --git a/benches/benchmarks.rs b/benches/benchmarks.rs index e8954c8..741e5aa 100644 --- a/benches/benchmarks.rs +++ b/benches/benchmarks.rs @@ -4,7 +4,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use rustframe::{ frame::{Frame, RowIndex}, - matrix::{BoolMatrix, Matrix, SeriesOps}, + matrix::{Axis, BoolMatrix, Matrix, SeriesOps}, utils::{DateFreq, DatesList}, }; use std::time::Duration; @@ -109,6 +109,47 @@ fn matrix_operations_benchmark(c: &mut Criterion, sizes: &[usize]) { let _result = &ma / 2.0; }); }); + c.bench_function( + &format!("matrix matrix_multiply ({}x{})", size, size), + |b| { + b.iter(|| { + let _result = ma.matrix_mul(&ma); + }); + }, + ); + c.bench_function(&format!("matrix sum_horizontal ({}x{})", size, size), |b| { + b.iter(|| { + let _result = ma.sum_horizontal(); + }); + }); + c.bench_function(&format!("matrix sum_vertical ({}x{})", size, size), |b| { + b.iter(|| { + let _result = ma.sum_vertical(); + }); + }); + c.bench_function( + &format!("matrix prod_horizontal ({}x{})", size, size), + |b| { + b.iter(|| { + let _result = ma.prod_horizontal(); + }); + }, + ); + c.bench_function(&format!("matrix prod_vertical ({}x{})", size, size), |b| { + b.iter(|| { + let _result = ma.prod_vertical(); + }); + }); + c.bench_function(&format!("matrix apply_axis ({}x{})", size, size), |b| { + b.iter(|| { + let _result = ma.apply_axis(Axis::Col, |col| col.iter().sum::()); + }); + }); + c.bench_function(&format!("matrix transpose ({}x{})", size, size), |b| { + b.iter(|| { + let _result = ma.transpose(); + }); + }); } for &size in sizes { @@ -187,6 +228,12 @@ fn benchmark_frame_operations(c: &mut Criterion, sizes: &[usize]) { }); }); + c.bench_function(&format!("frame matrix_multiply ({}x{})", size, size), |b| { + b.iter(|| { + let _result = fa.matrix_mul(&fb); + }); + }); + c.bench_function(&format!("frame sum_horizontal ({}x{})", size, size), |b| { b.iter(|| { let _result = fa.sum_horizontal(); @@ -207,6 +254,16 @@ fn benchmark_frame_operations(c: &mut Criterion, sizes: &[usize]) { let _result = fa.prod_vertical(); }); }); + c.bench_function(&format!("frame apply_axis ({}x{})", size, size), |b| { + b.iter(|| { + let _result = fa.apply_axis(Axis::Col, |col| col.iter().sum::()); + }); + }); + c.bench_function(&format!("frame transpose ({}x{})", size, size), |b| { + b.iter(|| { + let _result = fa.transpose(); + }); + }); } } diff --git a/example/game_of_life.rs b/examples/game_of_life.rs similarity index 76% rename from example/game_of_life.rs rename to examples/game_of_life.rs index 0af190f..eb88afd 100644 --- a/example/game_of_life.rs +++ b/examples/game_of_life.rs @@ -1,7 +1,113 @@ -// src/gol.rs - use rand::{self, Rng}; -use rustframe::matrix::{BoolMatrix, IntMatrix, Matrix}; +use rustframe::matrix::{BoolMatrix, BoolOps, IntMatrix, Matrix}; +use std::{thread, time}; + +const BOARD_SIZE: usize = 50; // Size of the board (50x50) +const TICK_DURATION_MS: u64 = 10; // Milliseconds per frame + +fn main() { + // Initialize the game board. + // This demonstrates `BoolMatrix::from_vec`. + let mut current_board = + BoolMatrix::from_vec(vec![false; BOARD_SIZE * BOARD_SIZE], BOARD_SIZE, BOARD_SIZE); + + let primes = generate_primes((BOARD_SIZE * BOARD_SIZE) as i32); + + add_simulated_activity(&mut current_board, BOARD_SIZE); + + let mut generation_count: u32 = 0; + // `previous_board_state` will store a clone of the board. + // This demonstrates `Matrix::clone()` and later `PartialEq` for `Matrix`. + let mut previous_board_state: Option = None; + let mut board_hashes = Vec::new(); + // let mut print_board_bool = true; + let mut print_bool_int = 0; + + loop { + // print!("{}[2J", 27 as char); // Clear screen and move cursor to top-left + + // if print_board_bool { + if print_bool_int % 10 == 0 { + print!("{}[2J", 27 as char); + println!("Conway's Game of Life - Generation: {}", generation_count); + + print_board(¤t_board); + println!("Alive cells: {}", ¤t_board.count()); + + // print_board_bool = false; + print_bool_int = 0; + } else { + // print_board_bool = true; + print_bool_int += 1; + } + // `current_board.count()` demonstrates a method from `BoolOps`. + board_hashes.push(hash_board(¤t_board, primes.clone())); + if detect_stable_state(¤t_board, &previous_board_state) { + println!( + "\nStable state detected at generation {}.", + generation_count + ); + add_simulated_activity(&mut current_board, BOARD_SIZE); + } + if detect_repeating_state(&mut board_hashes) { + println!( + "\nRepeating state detected at generation {}.", + generation_count + ); + add_simulated_activity(&mut current_board, BOARD_SIZE); + } + if !¤t_board.any() { + println!("\nExtinction at generation {}.", generation_count); + add_simulated_activity(&mut current_board, BOARD_SIZE); + } + + // `current_board.clone()` demonstrates `Clone` for `Matrix`. + previous_board_state = Some(current_board.clone()); + + // This is the core call to your game logic. + let next_board = game_of_life_next_frame(¤t_board); + current_board = next_board; + + generation_count += 1; + thread::sleep(time::Duration::from_millis(TICK_DURATION_MS)); + + // if generation_count > 500 { // Optional limit + // println!("\nReached generation limit."); + // break; + // } + } +} + +/// Prints the Game of Life board to the console. +/// +/// - `board`: A reference to the `BoolMatrix` representing the current game state. +/// This function demonstrates `board.rows()`, `board.cols()`, and `board[(r, c)]` (Index trait). +fn print_board(board: &BoolMatrix) { + let mut print_str = String::new(); + print_str.push_str("+"); + for _ in 0..board.cols() { + print_str.push_str("--"); + } + print_str.push_str("+\n"); + for r in 0..board.rows() { + print_str.push_str("| "); + for c in 0..board.cols() { + if board[(r, c)] { + // Using Index trait for Matrix + print_str.push_str("██"); + } else { + print_str.push_str(" "); + } + } + print_str.push_str(" |\n"); + } + print_str.push_str("+"); + for _ in 0..board.cols() { + print_str.push_str("--"); + } + print_str.push_str("+\n\n"); + print!("{}", print_str); +} /// Helper function to create a shifted version of the game board. /// (Using the version provided by the user) diff --git a/src/frame/base.rs b/src/frame/base.rs index 08a969d..5b485a4 100644 --- a/src/frame/base.rs +++ b/src/frame/base.rs @@ -471,6 +471,12 @@ impl Frame { deleted_data } + + /// Returns a new `Matrix` that is the transpose of the current frame's matrix. + pub fn transpose(&self) -> Matrix { + self.matrix.transpose() + } + /// Sorts columns alphabetically by name, preserving data associations. pub fn sort_columns(&mut self) { let n = self.column_names.len(); diff --git a/src/frame/ops.rs b/src/frame/ops.rs index c35625e..dea2396 100644 --- a/src/frame/ops.rs +++ b/src/frame/ops.rs @@ -20,6 +20,13 @@ impl SeriesOps for Frame { { self.matrix().apply_axis(axis, f) } + fn matrix_mul(&self, other: &Self) -> FloatMatrix { + self.matrix().matrix_mul(other.matrix()) + } + + fn dot(&self, other: &Self) -> FloatMatrix { + self.matrix().dot(other.matrix()) + } delegate_to_matrix!( sum_vertical -> Vec, @@ -128,6 +135,17 @@ mod tests { ); assert_eq!(frame.is_nan(), frame.matrix().is_nan()); assert_eq!(frame.apply_axis(Axis::Row, |x| x[0] + x[1]), vec![4.0, 6.0]); + + + assert_eq!(frame.matrix_mul(&frame), frame.matrix().matrix_mul(&frame.matrix())); + assert_eq!(frame.dot(&frame), frame.matrix().dot(&frame.matrix())); + + // test transpose - returns a matrix. + let frame_transposed_mat = frame.transpose(); + let frame_mat_transposed = frame.matrix().transpose(); + assert_eq!(frame_transposed_mat, frame_mat_transposed); + assert_eq!(frame.matrix(), &frame.matrix().transpose().transpose()); + } #[test] @@ -152,4 +170,7 @@ mod tests { vec![false, false] ); } + + + } diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index 18eb804..64bf6de 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -76,7 +76,7 @@ impl Matrix { self.data } - /// Creates a new Vec containing the matrix data (cloned). + /// Creates a new `Vec` containing the matrix data (cloned). pub fn to_vec(&self) -> Vec { self.data.clone() } @@ -168,7 +168,12 @@ impl Matrix { /// Deletes a column from the matrix. Panics on out-of-bounds. /// This is O(N) where N is the number of elements. pub fn delete_column(&mut self, col: usize) { - assert!(col < self.cols, "column index {} out of bounds for {} columns", col, self.cols); + assert!( + col < self.cols, + "column index {} out of bounds for {} columns", + col, + self.cols + ); let start = col * self.rows; self.data.drain(start..start + self.rows); // Efficient removal self.cols -= 1; @@ -203,6 +208,22 @@ impl Matrix { self.data = new_data; self.rows = new_rows; } + + pub fn transpose(&self) -> Matrix { + let (m, n) = (self.rows, self.cols); + let mut transposed_data = Vec::with_capacity(m * n); + + // In the transposed matrix the old rows become the new columns. + for j in 0..m { + // new column index = old row index + for i in 0..n { + // new row index = old col index + transposed_data.push(self[(j, i)].clone()); // A(T)[i,j] = A[j,i] + } + } + + Matrix::from_vec(transposed_data, n, m) // size is n × m + } } impl Matrix { @@ -238,8 +259,19 @@ impl Matrix { /// Adds a row to the matrix at the specified index. Panics if index > rows or length mismatch. /// This is O(N) where N is the number of elements, as it rebuilds the data vec. pub fn add_row(&mut self, index: usize, row: Vec) { - assert!(index <= self.rows, "add_row index {} out of bounds for {} rows", index, self.rows); - assert_eq!(row.len(), self.cols, "row length mismatch: expected {} (cols), got {}", self.cols, row.len()); + assert!( + index <= self.rows, + "add_row index {} out of bounds for {} rows", + index, + self.rows + ); + assert_eq!( + row.len(), + self.cols, + "row length mismatch: expected {} (cols), got {}", + self.cols, + row.len() + ); if self.cols == 0 && self.rows == 0 { // Special case: adding first row to empty matrix @@ -624,7 +656,6 @@ pub type BoolMatrix = Matrix; pub type IntMatrix = Matrix; pub type StringMatrix = Matrix; - // --- Unit Tests --- #[cfg(test)] @@ -903,6 +934,40 @@ mod tests { assert_eq!(matrix[(1, 2)], 6); // Second row, third col } + #[test] + fn test_transpose() { + let matrix = static_test_matrix(); + let transposed = matrix.transpose(); + let round_triped = transposed.transpose(); + assert_eq!( + round_triped, matrix, + "Transposing twice should return original matrix" + ); + for r in 0..matrix.rows() { + for c in 0..matrix.cols() { + assert_eq!(matrix[(r, c)], transposed[(c, r)]); + } + } + } + + #[test] + fn test_transpose_big() { + let data: Vec = (1..=20000).collect(); // + let matrix = Matrix::from_vec(data, 100, 200); + let transposed = matrix.transpose(); + assert_eq!(transposed.rows(), 200); + assert_eq!(transposed.cols(), 100); + assert_eq!(transposed.data().len(), 20000); + assert_eq!(transposed[(0, 0)], 1); + + let round_trip = transposed.transpose(); + + assert_eq!( + round_trip, matrix, + "Transposing back should return original matrix" + ); + } + #[test] #[should_panic(expected = "data length mismatch")] fn test_from_vec_wrong_length() { @@ -999,20 +1064,20 @@ mod tests { #[test] #[should_panic(expected = "index out of bounds")] fn test_index_out_of_bounds_row_alt() { - let matrix = static_test_matrix(); // 3x3 + let matrix = static_test_matrix(); let _ = matrix[(3, 0)]; } #[test] #[should_panic(expected = "index out of bounds")] fn test_index_out_of_bounds_col_alt() { - let matrix = static_test_matrix(); // 3x3 + let matrix = static_test_matrix(); let _ = matrix[(0, 3)]; } #[test] fn test_index_mut_and_get_mut() { - let mut matrix = static_test_matrix(); // 3x3 + let mut matrix = static_test_matrix(); matrix[(0, 0)] = 10; matrix[(1, 1)] = 20; @@ -1028,33 +1093,26 @@ mod tests { assert_eq!(matrix[(0, 1)], 15); assert_eq!(matrix[(2, 1)], 25); - // Check underlying data consistency (column-major) - // Should be: - // 10 15 7 - // 2 20 8 - // 3 25 30 assert_eq!(matrix.data(), &[10, 2, 3, 15, 20, 25, 7, 8, 30]); } #[test] #[should_panic(expected = "index out of bounds")] fn test_index_mut_out_of_bounds_row() { - let mut matrix = static_test_matrix(); // 3x3 + let mut matrix = static_test_matrix(); matrix[(3, 0)] = 99; } #[test] #[should_panic(expected = "index out of bounds")] fn test_index_mut_out_of_bounds_col() { - let mut matrix = static_test_matrix(); // 3x3 + let mut matrix = static_test_matrix(); matrix[(0, 3)] = 99; } #[test] fn test_column() { - let matrix = static_test_matrix_2x4(); // 2x4 - // 1 3 5 7 - // 2 4 6 8 + let matrix = static_test_matrix_2x4(); assert_eq!(matrix.column(0), &[1, 2]); assert_eq!(matrix.column(1), &[3, 4]); @@ -1065,15 +1123,13 @@ mod tests { #[test] #[should_panic(expected = "column index 4 out of bounds for 4 columns")] fn test_column_out_of_bounds() { - let matrix = static_test_matrix_2x4(); // 2x4 + let matrix = static_test_matrix_2x4(); matrix.column(4); } #[test] fn test_column_mut() { - let mut matrix = static_test_matrix_2x4(); // 2x4 - // 1 3 5 7 - // 2 4 6 8 + let mut matrix = static_test_matrix_2x4(); let col1_mut = matrix.column_mut(1); col1_mut[0] = 30; @@ -1088,25 +1144,19 @@ mod tests { assert_eq!(matrix[(0, 3)], 70); assert_eq!(matrix[(1, 3)], 8); // Unchanged - // Check underlying data (column-major) - // Should be: - // 1 30 5 70 - // 2 40 6 8 assert_eq!(matrix.data(), &[1, 2, 30, 40, 5, 6, 70, 8]); } #[test] #[should_panic(expected = "column index 4 out of bounds for 4 columns")] fn test_column_mut_out_of_bounds() { - let mut matrix = static_test_matrix_2x4(); // 2x4 + let mut matrix = static_test_matrix_2x4(); matrix.column_mut(4); } #[test] fn test_iter_columns() { - let matrix = static_test_matrix_2x4(); // 2x4 - // 1 3 5 7 - // 2 4 6 8 + let matrix = static_test_matrix_2x4(); let cols: Vec<&[i32]> = matrix.iter_columns().collect(); assert_eq!(cols.len(), 4); @@ -1118,9 +1168,7 @@ mod tests { #[test] fn test_iter_rows() { - let matrix = static_test_matrix_2x4(); // 2x4 - // 1 3 5 7 - // 2 4 6 8 + let matrix = static_test_matrix_2x4(); let rows: Vec> = matrix .iter_rows() @@ -1134,10 +1182,7 @@ mod tests { // test data_mut #[test] fn test_data_mut() { - let mut matrix = static_test_matrix(); // 3x3 - // 1 4 7 - // 2 5 8 - // 3 6 9 + let mut matrix = static_test_matrix(); let data_mut = matrix.data_mut(); data_mut[0] = 10; @@ -1149,10 +1194,7 @@ mod tests { #[test] fn test_matrix_row_get_and_iter() { - let matrix = static_test_matrix_2x4(); // 2x4 - // 1 3 5 7 - // 2 4 6 8 - + let matrix = static_test_matrix_2x4(); let row0 = matrix.iter_rows().next().unwrap(); assert_eq!(*row0.get(0), 1); assert_eq!(*row0.get(1), 3); @@ -1169,17 +1211,9 @@ mod tests { #[test] fn test_swap_columns() { - let mut matrix = static_test_matrix(); // 3x3 - // 1 4 7 - // 2 5 8 - // 3 6 9 + let mut matrix = static_test_matrix(); - matrix.swap_columns(0, 2); // Swap first and last - - // Should be: - // 7 4 1 - // 8 5 2 - // 9 6 3 + matrix.swap_columns(0, 2); // swap first and last assert_eq!(matrix.rows(), 3); assert_eq!(matrix.cols(), 3); @@ -1193,7 +1227,7 @@ mod tests { assert_eq!(matrix[(1, 2)], 2); assert_eq!(matrix[(2, 2)], 3); - // Swap the same column (should do nothing) + // swap the same column (should do nothing) let original_data = matrix.data().to_vec(); matrix.swap_columns(1, 1); assert_eq!(matrix.data(), &original_data); // Data should be identical @@ -1205,22 +1239,15 @@ mod tests { #[test] #[should_panic(expected = "column index c2=3 out of bounds for 3 columns")] fn test_swap_columns_out_of_bounds() { - let mut matrix = static_test_matrix(); // 3x3 + let mut matrix = static_test_matrix(); matrix.swap_columns(0, 3); } #[test] fn test_delete_column() { - let mut matrix = static_test_matrix_2x4(); // 2x4 - // 1 3 5 7 - // 2 4 6 8 - + let mut matrix = static_test_matrix_2x4(); matrix.delete_column(1); // Delete the second column (index 1) - // Should be: - // 1 5 7 - // 2 6 8 - assert_eq!(matrix.rows(), 2); assert_eq!(matrix.cols(), 3); assert_eq!(matrix[(0, 0)], 1); @@ -1230,23 +1257,17 @@ mod tests { assert_eq!(matrix[(0, 2)], 7); assert_eq!(matrix[(1, 2)], 8); - // Check underlying data (column-major) + // check underlying data assert_eq!(matrix.data(), &[1, 2, 5, 6, 7, 8]); // Delete the first column matrix.delete_column(0); - // Should be: - // 5 7 - // 6 8 assert_eq!(matrix.rows(), 2); assert_eq!(matrix.cols(), 2); assert_eq!(matrix.data(), &[5, 6, 7, 8]); // Delete the last column matrix.delete_column(1); - // Should be: - // 5 - // 6 assert_eq!(matrix.rows(), 2); assert_eq!(matrix.cols(), 1); assert_eq!(matrix.data(), &[5, 6]); @@ -1262,22 +1283,15 @@ mod tests { #[test] #[should_panic(expected = "column index 4 out of bounds for 4 columns")] fn test_delete_column_out_of_bounds() { - let mut matrix = static_test_matrix_2x4(); // 2x4 + let mut matrix = static_test_matrix_2x4(); matrix.delete_column(4); } #[test] fn test_delete_row() { - let mut matrix = static_test_matrix(); // 3x3 - // 1 4 7 - // 2 5 8 - // 3 6 9 + let mut matrix = static_test_matrix(); - matrix.delete_row(1); // Delete the second row (index 1) - - // Should be: - // 1 4 7 - // 3 6 9 + matrix.delete_row(1); // Delete the second row assert_eq!(matrix.rows(), 2); assert_eq!(matrix.cols(), 3); @@ -1288,9 +1302,7 @@ mod tests { assert_eq!(matrix[(0, 2)], 7); assert_eq!(matrix[(1, 2)], 9); - // Check underlying data (column-major) - // Original: [1, 2, 3, 4, 5, 6, 7, 8, 9] - // Delete row 1: [1, 3, 4, 6, 7, 9] + // check underlying data (column-major) assert_eq!(matrix.data(), &[1, 3, 4, 6, 7, 9]); // Delete the first row @@ -1312,22 +1324,15 @@ mod tests { #[test] #[should_panic(expected = "row index 3 out of bounds for 3 rows")] fn test_delete_row_out_of_bounds() { - let mut matrix = static_test_matrix(); // 3x3 + let mut matrix = static_test_matrix(); matrix.delete_row(3); } #[test] fn test_add_column() { - let mut matrix = static_test_matrix_2x4(); // 2x4 - // 1 3 5 7 - // 2 4 6 8 - + let mut matrix = static_test_matrix_2x4(); let new_col = vec![9, 10]; - matrix.add_column(2, new_col); // Add at index 2 - - // Should be: - // 1 3 9 5 7 - // 2 4 10 6 8 + matrix.add_column(2, new_col); assert_eq!(matrix.rows(), 2); assert_eq!(matrix.cols(), 5); @@ -1342,17 +1347,12 @@ mod tests { assert_eq!(matrix[(0, 4)], 7); // Shifted assert_eq!(matrix[(1, 4)], 8); - // Check underlying data (column-major) - // Original: [1, 2, 3, 4, 5, 6, 7, 8] - // Add [9, 10] at index 2: [1, 2, 3, 4, 9, 10, 5, 6, 7, 8] + // Check underlying data assert_eq!(matrix.data(), &[1, 2, 3, 4, 9, 10, 5, 6, 7, 8]); // Add a column at the beginning let new_col_start = vec![11, 12]; matrix.add_column(0, new_col_start); - // Should be: - // 11 1 3 9 5 7 - // 12 2 4 10 6 8 assert_eq!(matrix.rows(), 2); assert_eq!(matrix.cols(), 6); assert_eq!(matrix[(0, 0)], 11); @@ -1362,9 +1362,6 @@ mod tests { // Add a column at the end let new_col_end = vec![13, 14]; matrix.add_column(6, new_col_end); - // Should be: - // 11 1 3 9 5 7 13 - // 12 2 4 10 6 8 14 assert_eq!(matrix.rows(), 2); assert_eq!(matrix.cols(), 7); assert_eq!(matrix[(0, 6)], 13); @@ -1378,7 +1375,7 @@ mod tests { #[test] #[should_panic(expected = "add_column index 5 out of bounds for 4 columns")] fn test_add_column_out_of_bounds() { - let mut matrix = static_test_matrix_2x4(); // 2x4 + let mut matrix = static_test_matrix_2x4(); let new_col = vec![9, 10]; matrix.add_column(5, new_col); // Index 5 is out of bounds for 4 columns } @@ -1386,24 +1383,16 @@ mod tests { #[test] #[should_panic(expected = "column length mismatch")] fn test_add_column_length_mismatch() { - let mut matrix = static_test_matrix_2x4(); // 2x4 (2 rows) + let mut matrix = static_test_matrix_2x4(); let new_col = vec![9, 10, 11]; // Wrong length matrix.add_column(0, new_col); } #[test] fn test_add_row() { - let mut matrix = static_test_matrix_2x4(); // 2x4 - // 1 3 5 7 - // 2 4 6 8 - + let mut matrix = static_test_matrix_2x4(); let new_row = vec![9, 10, 11, 12]; - matrix.add_row(1, new_row); // Add at index 1 - - // Should be: - // 1 3 5 7 - // 9 10 11 12 - // 2 4 6 8 + matrix.add_row(1, new_row); assert_eq!(matrix.rows(), 3); assert_eq!(matrix.cols(), 4); @@ -1422,23 +1411,11 @@ mod tests { assert_eq!(matrix[(2, 3)], 8); // Check underlying data (column-major) - // Original: [1, 2, 3, 4, 5, 6, 7, 8] (rows 0, 1) - // Add [9, 10, 11, 12] at index 1 (new row will be index 1, original row 1 becomes index 2) - // Col 0: [1, 9, 2] - // Col 1: [3, 10, 4] - // Col 2: [5, 11, 6] - // Col 3: [7, 12, 8] - // Data: [1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8] assert_eq!(matrix.data(), &[1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]); // Add a row at the beginning let new_row_start = vec![13, 14, 15, 16]; matrix.add_row(0, new_row_start); - // Should be: - // 13 14 15 16 - // 1 3 5 7 - // 9 10 11 12 - // 2 4 6 8 assert_eq!(matrix.rows(), 4); assert_eq!(matrix.cols(), 4); assert_eq!(matrix[(0, 0)], 13); @@ -1452,12 +1429,6 @@ mod tests { // Add a row at the end let new_row_end = vec![17, 18, 19, 20]; matrix.add_row(4, new_row_end); - // Should be: - // 13 14 15 16 - // 1 3 5 7 - // 9 10 11 12 - // 2 4 6 8 - // 17 18 19 20 assert_eq!(matrix.rows(), 5); assert_eq!(matrix.cols(), 4); assert_eq!(matrix[(4, 0)], 17); @@ -1467,7 +1438,7 @@ mod tests { #[test] #[should_panic(expected = "add_row index 3 out of bounds for 2 rows")] fn test_add_row_out_of_bounds() { - let mut matrix = static_test_matrix_2x4(); // 2x4 + let mut matrix = static_test_matrix_2x4(); let new_row = vec![9, 10, 11, 12]; matrix.add_row(3, new_row); // Index 3 is out of bounds for 2 rows } @@ -1475,26 +1446,22 @@ mod tests { #[test] #[should_panic(expected = "row length mismatch")] fn test_add_row_length_mismatch() { - let mut matrix = static_test_matrix_2x4(); // 2x4 (4 cols) + let mut matrix = static_test_matrix_2x4(); let new_row = vec![9, 10, 11]; // Wrong length matrix.add_row(0, new_row); } #[test] fn test_elementwise_add() { - let matrix1 = static_test_matrix(); // 3x3 - let matrix2 = Matrix::from_vec(vec![9, 8, 7, 6, 5, 4, 3, 2, 1], 3, 3); // 3x3 + let matrix1 = static_test_matrix(); + let matrix2 = Matrix::from_vec(vec![9, 8, 7, 6, 5, 4, 3, 2, 1], 3, 3); let result = &matrix1 + &matrix2; assert_eq!(result.rows(), 3); assert_eq!(result.cols(), 3); - // Expected: - // 1+9 4+6 7+3 => 10 10 10 - // 2+8 5+5 8+2 => 10 10 10 - // 3+7 6+4 9+1 => 10 10 10 - // Column-major data: [10, 10, 10, 10, 10, 10, 10, 10, 10] + // Expect all 10s assert_eq!(result.data(), &[10, 10, 10, 10, 10, 10, 10, 10, 10]); assert_eq!(result[(0, 0)], 10); assert_eq!(result[(1, 1)], 10); @@ -1503,19 +1470,14 @@ mod tests { #[test] fn test_elementwise_sub() { - let matrix1 = static_test_matrix(); // 3x3 - let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], 3, 3); // 3x3 + let matrix1 = static_test_matrix(); + let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], 3, 3); let result = &matrix1 - &matrix2; assert_eq!(result.rows(), 3); assert_eq!(result.cols(), 3); - // Expected: - // 1-1 4-2 7-3 => 0 2 4 - // 2-1 5-2 8-3 => 1 3 5 - // 3-1 6-2 9-3 => 2 4 6 - // Column-major data: [0, 1, 2, 2, 3, 4, 4, 5, 6] assert_eq!(result.data(), &[0, 1, 2, 2, 3, 4, 4, 5, 6]); assert_eq!(result[(0, 0)], 0); assert_eq!(result[(1, 1)], 3); @@ -1524,19 +1486,15 @@ mod tests { #[test] fn test_elementwise_mul() { - let matrix1 = static_test_matrix(); // 3x3 - let matrix2 = Matrix::from_vec(vec![1, 2, 3, 1, 2, 3, 1, 2, 3], 3, 3); // 3x3 + let matrix1 = static_test_matrix(); + let matrix2 = Matrix::from_vec(vec![1, 2, 3, 1, 2, 3, 1, 2, 3], 3, 3); let result = &matrix1 * &matrix2; assert_eq!(result.rows(), 3); assert_eq!(result.cols(), 3); - // Expected: - // 1*1 4*1 7*1 => 1 4 7 - // 2*2 5*2 8*2 => 4 10 16 - // 3*3 6*3 9*3 => 9 18 27 - // Column-major data: [1, 4, 9, 4, 10, 18, 7, 16, 27] + // Expected assert_eq!(result.data(), &[1, 4, 9, 4, 10, 18, 7, 16, 27]); assert_eq!(result[(0, 0)], 1); assert_eq!(result[(1, 1)], 10); @@ -1545,19 +1503,14 @@ mod tests { #[test] fn test_elementwise_div() { - let matrix1 = static_test_matrix(); // 3x3 - let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 7, 8, 9], 3, 3); // 3x3 + let matrix1 = static_test_matrix(); + let matrix2 = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 7, 8, 9], 3, 3); let result = &matrix1 / &matrix2; // Integer division assert_eq!(result.rows(), 3); assert_eq!(result.cols(), 3); - // Expected: - // 1/1 4/2 7/7 => 1 2 1 - // 2/1 5/2 8/8 => 2 2 1 (integer division) - // 3/1 6/2 9/9 => 3 3 1 - // Column-major data: [1, 2, 3, 2, 2, 3, 1, 1, 1] assert_eq!(result.data(), &[1, 2, 3, 2, 2, 3, 1, 1, 1]); assert_eq!(result[(0, 0)], 1); assert_eq!(result[(1, 1)], 2); @@ -1567,34 +1520,26 @@ mod tests { #[test] #[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 rows")] fn test_elementwise_op_row_mismatch() { - let matrix1 = static_test_matrix(); // 3x3 - let matrix2 = static_test_matrix_2x4(); // 2x4 + let matrix1 = static_test_matrix(); + let matrix2 = static_test_matrix_2x4(); let _ = &matrix1 + &matrix2; // Should panic } #[test] #[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 ro")] fn test_elementwise_op_col_mismatch() { - let matrix1 = static_test_matrix(); // 3x3 - let matrix2 = static_test_matrix_2x4(); // 2x4 + let matrix1 = static_test_matrix(); + let matrix2 = static_test_matrix_2x4(); let _ = &matrix1 * &matrix2; // Should panic } #[test] fn test_bitwise_and() { - let data1 = vec![true, false, true, false, true, false]; // 2x3 - let data2 = vec![true, true, false, false, true, true]; // 2x3 + let data1 = vec![true, false, true, false, true, false]; + let data2 = vec![true, true, false, false, true, true]; let matrix1 = BoolMatrix::from_vec(data1, 2, 3); let matrix2 = BoolMatrix::from_vec(data2, 2, 3); - // Expected column-major results: - // T & T = T - // F & T = F - // T & F = F - // F & F = F - // T & T = T - // F & T = F - // Data: [T, F, F, F, T, F] let expected_data = vec![true, false, false, false, true, false]; let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); @@ -1604,19 +1549,11 @@ mod tests { #[test] fn test_bitwise_or() { - let data1 = vec![true, false, true, false, true, false]; // 2x3 - let data2 = vec![true, true, false, false, true, true]; // 2x3 + let data1 = vec![true, false, true, false, true, false]; + let data2 = vec![true, true, false, false, true, true]; let matrix1 = BoolMatrix::from_vec(data1, 2, 3); let matrix2 = BoolMatrix::from_vec(data2, 2, 3); - // Expected column-major results: - // T | T = T - // F | T = T - // T | F = T - // F | F = F - // T | T = T - // F | T = T - // Data: [T, T, T, F, T, T] let expected_data = vec![true, true, true, false, true, true]; let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); @@ -1626,19 +1563,11 @@ mod tests { #[test] fn test_bitwise_xor() { - let data1 = vec![true, false, true, false, true, false]; // 2x3 - let data2 = vec![true, true, false, false, true, true]; // 2x3 + let data1 = vec![true, false, true, false, true, false]; + let data2 = vec![true, true, false, false, true, true]; let matrix1 = BoolMatrix::from_vec(data1, 2, 3); let matrix2 = BoolMatrix::from_vec(data2, 2, 3); - // Expected column-major results: - // T ^ T = F - // F ^ T = T - // T ^ F = T - // F ^ F = F - // T ^ T = F - // F ^ T = T - // Data: [F, T, T, F, F, T] let expected_data = vec![false, true, true, false, false, true]; let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); @@ -1648,17 +1577,9 @@ mod tests { #[test] fn test_bitwise_not() { - let data = vec![true, false, true, false, true, false]; // 2x3 + let data = vec![true, false, true, false, true, false]; let matrix = BoolMatrix::from_vec(data, 2, 3); - // Expected column-major results: - // !T = F - // !F = T - // !T = F - // !F = T - // !T = F - // !F = T - // Data: [F, T, F, T, F, T] let expected_data = vec![false, true, false, true, false, true]; let expected_matrix = BoolMatrix::from_vec(expected_data, 2, 3); @@ -1669,8 +1590,8 @@ mod tests { #[test] #[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")] fn test_bitwise_op_row_mismatch() { - let data1 = vec![true, false, true, false]; // 2x2 - let data2 = vec![true, true, false, false, true, true]; // 2x3 + let data1 = vec![true, false, true, false]; + let data2 = vec![true, true, false, false, true, true]; let matrix1 = BoolMatrix::from_vec(data1, 2, 2); let matrix2 = BoolMatrix::from_vec(data2, 2, 3); let _ = &matrix1 & &matrix2; // Should panic @@ -1679,8 +1600,8 @@ mod tests { #[test] #[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")] fn test_bitwise_op_col_mismatch() { - let data1 = vec![true, false, true, false]; // 2x2 - let data2 = vec![true, true, false, false, true, true]; // 2x3 + let data1 = vec![true, false, true, false]; + let data2 = vec![true, true, false, false, true, true]; let matrix1 = BoolMatrix::from_vec(data1, 2, 2); let matrix2 = BoolMatrix::from_vec(data2, 2, 3); let _ = &matrix1 | &matrix2; // Should panic @@ -1695,7 +1616,7 @@ mod tests { "c".to_string(), "d".to_string(), ]; - let matrix = StringMatrix::from_vec(data.clone(), 2, 2); // 2x2 + let matrix = StringMatrix::from_vec(data.clone(), 2, 2); assert_eq!(matrix[(0, 0)], "a".to_string()); assert_eq!(matrix[(1, 0)], "b".to_string()); @@ -1707,12 +1628,10 @@ mod tests { matrix[(0, 0)] = "hello".to_string(); assert_eq!(matrix[(0, 0)], "hello".to_string()); - // Test add_column (requires Clone) + // Test add_column let new_col = vec!["e".to_string(), "f".to_string()]; - matrix.add_column(1, new_col); // Add at index 1 - // Should be: - // hello c d - // b e f + matrix.add_column(1, new_col); + assert_eq!(matrix.rows(), 2); assert_eq!(matrix.cols(), 3); assert_eq!(matrix[(0, 0)], "hello".to_string()); @@ -1724,11 +1643,8 @@ mod tests { // Test add_row (requires Clone) let new_row = vec!["g".to_string(), "h".to_string(), "i".to_string()]; - matrix.add_row(0, new_row); // Add at index 0 - // Should be: - // g h i - // hello e c - // b f d + matrix.add_row(0, new_row); + assert_eq!(matrix.rows(), 3); assert_eq!(matrix.cols(), 3); assert_eq!(matrix[(0, 0)], "g".to_string()); @@ -1740,8 +1656,8 @@ mod tests { #[test] fn test_float_matrix_ops() { - let data1 = vec![1.0, 2.0, 3.0, 4.0]; // 2x2 - let data2 = vec![0.5, 1.5, 2.5, 3.5]; // 2x2 + let data1 = vec![1.0, 2.0, 3.0, 4.0]; + let data2 = vec![0.5, 1.5, 2.5, 3.5]; let matrix1 = FloatMatrix::from_vec(data1, 2, 2); let matrix2 = FloatMatrix::from_vec(data2, 2, 2); @@ -1750,17 +1666,13 @@ mod tests { let prod = &matrix1 * &matrix2; let div = &matrix1 / &matrix2; - // Check sums (col-major): [1.5, 3.5, 5.5, 7.5] assert_eq!(sum.data(), &[1.5, 3.5, 5.5, 7.5]); - // Check diffs (col-major): [0.5, 0.5, 0.5, 0.5] assert_eq!(diff.data(), &[0.5, 0.5, 0.5, 0.5]); - // Check prods (col-major): [0.5, 3.0, 7.5, 14.0] assert_eq!(prod.data(), &[0.5, 3.0, 7.5, 14.0]); // Check divs (col-major): [2.0, 1.333..., 1.2, 1.14...] - // Using element access for more specific checks on floating point results assert_eq!(div.rows(), 2); assert_eq!(div.cols(), 2); assert!((div[(0, 0)] - 1.0 / 0.5).abs() < 1e-9); // 2.0 @@ -1785,20 +1697,18 @@ mod tests { // Perform the swap directly on the matrix matrix.swap_columns(0, 2); // Swap column 0 and column 2 - // --- Assertions --- - - // 1. Verify the dimensions are unchanged + // Verify the dimensions are unchanged assert_eq!(matrix.rows(), 3, "Matrix rows should remain unchanged"); assert_eq!(matrix.cols(), 3, "Matrix cols should remain unchanged"); - // 2. Verify the column that was NOT swapped is unchanged + // Verify the column that was NOT swapped is unchanged assert_eq!( matrix.column(1), initial_col1_data.as_slice(), // Comparing slice to slice "Column 1 data should be unchanged" ); - // 3. Verify the data swap occurred correctly using the COLUMN ACCESSOR + // Verify the data swap occurred correctly using the COLUMN ACCESSOR // The data originally at index 0 should now be at index 2 assert_eq!( matrix.column(2), @@ -1812,16 +1722,14 @@ mod tests { "Column 0 should now contain the original data from column 2" ); - // 4. (Optional but useful) Verify the underlying raw data vector - // Original data: [1, 2, 3, 4, 5, 6, 7, 8, 9] - // Expected data after swapping col 0 and col 2: [7, 8, 9, 4, 5, 6, 1, 2, 3] + // Verify the underlying raw data vector assert_eq!( matrix.data(), &[7, 8, 9, 4, 5, 6, 1, 2, 3], "Underlying data vector is incorrect after swap" ); - // 5. Test swapping with self (should be a no-op) + // Test swapping with self (should be a no-op) let state_before_self_swap = matrix.clone(); matrix.swap_columns(1, 1); assert_eq!( @@ -1829,7 +1737,7 @@ mod tests { "Swapping a column with itself should not change the matrix" ); - // 6. Test swapping adjacent columns + // Test swapping adjacent columns let mut matrix2 = create_test_matrix_i32(); let initial_col0_data_m2 = matrix2.column(0).to_vec(); let initial_col1_data_m2 = matrix2.column(1).to_vec(); diff --git a/src/matrix/seriesops.rs b/src/matrix/seriesops.rs index bc34a1f..65702e9 100644 --- a/src/matrix/seriesops.rs +++ b/src/matrix/seriesops.rs @@ -12,6 +12,9 @@ pub trait SeriesOps { where F: FnMut(&[f64]) -> U; + fn matrix_mul(&self, other: &Self) -> FloatMatrix; + fn dot(&self, other: &Self) -> FloatMatrix; + fn sum_vertical(&self) -> Vec; fn sum_horizontal(&self) -> Vec; @@ -139,11 +142,37 @@ impl SeriesOps for FloatMatrix { let data = self.data().iter().map(|v| v.is_nan()).collect(); BoolMatrix::from_vec(data, self.rows(), self.cols()) } -} + fn matrix_mul(&self, other: &Self) -> FloatMatrix { + let (m, n) = (self.rows(), self.cols()); + let (n2, p) = (other.rows(), other.cols()); + assert_eq!( + n, n2, + "Cannot multiply: left is {}x{}, right is {}x{}", + m, n, n2, p + ); + + // Column-major addressing: element (row i, col j) lives at j * m + i + let mut data = vec![0.0; m * p]; + for i in 0..m { + for j in 0..p { + let mut sum = 0.0; + for k in 0..n { + sum += self[(i, k)] * other[(k, j)]; + } + data[j * m + i] = sum; // <-- fixed index + } + } + FloatMatrix::from_vec(data, m, p) + } + fn dot(&self, other: &Self) -> FloatMatrix { + self.matrix_mul(other) + } +} #[cfg(test)] mod tests { + use super::*; // Helper function to create a FloatMatrix for SeriesOps testing @@ -156,6 +185,22 @@ mod tests { FloatMatrix::from_vec(data, 3, 3) } + fn create_float_test_matrix_4x4() -> FloatMatrix { + // 4x4 matrix (column-major) with some NaNs + // 1.0 5.0 9.0 13.0 + // 2.0 NaN 10.0 NaN + // 3.0 6.0 NaN 14.0 + // NaN 7.0 11.0 NaN + // first make array with 16 elements + FloatMatrix::from_vec( + (0..16) + .map(|i| if i % 5 == 0 { f64::NAN } else { i as f64 }) + .collect(), + 4, + 4, + ) + } + // --- Tests for SeriesOps (FloatMatrix) --- #[test] @@ -256,6 +301,48 @@ mod tests { assert_eq!(matrix.is_nan(), expected_matrix); } + #[test] + fn test_series_ops_matrix_mul() { + let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix + let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix + // result should be: 23, 34, 31, 46 + let expected = FloatMatrix::from_vec(vec![23.0, 34.0, 31.0, 46.0], 2, 2); + assert_eq!(a.matrix_mul(&b), expected); + + assert_eq!(a.dot(&b), a.matrix_mul(&b)); // dot should be the same as matrix_mul for FloatMatrix + } + #[test] + fn test_series_ops_matrix_mul_with_nans() { + let a = create_float_test_matrix(); // 3x3 matrix with some NaNs + let b = create_float_test_matrix(); // 3x3 matrix with some NaNs + + let mut result_vec = Vec::new(); + result_vec.push(30.0); + for _ in 1..9 { + result_vec.push(f64::NAN); + } + let expected = FloatMatrix::from_vec(result_vec, 3, 3); + + let result = a.matrix_mul(&b); + + assert_eq!(result.is_nan(), expected.is_nan()); + assert_eq!( + result.count_nan_horizontal(), + expected.count_nan_horizontal() + ); + assert_eq!(result.count_nan_vertical(), expected.count_nan_vertical()); + assert_eq!(result[(0, 0)], expected[(0, 0)]); + } + + #[test] + #[should_panic(expected = "Cannot multiply: left is 3x3, right is 4x4")] + fn test_series_ops_matrix_mul_errors() { + let a = create_float_test_matrix(); + let b = create_float_test_matrix_4x4(); + + a.dot(&b); // This should panic due to dimension mismatch + } + // --- Edge Cases for SeriesOps --- #[test] @@ -286,7 +373,7 @@ mod tests { #[test] fn test_series_ops_1xn_matrix() { let matrix = FloatMatrix::from_vec(vec![1.0, f64::NAN, 3.0, 4.0], 1, 4); // 1 row, 4 cols - // Data: [1.0, NaN, 3.0, 4.0] + // Data: [1.0, NaN, 3.0, 4.0] // Vertical (sums/prods/counts per column - each col is just one element) assert_eq!(matrix.sum_vertical(), vec![1.0, 0.0, 3.0, 4.0]); // NaN sum is 0 @@ -316,7 +403,7 @@ mod tests { #[test] fn test_series_ops_nx1_matrix() { let matrix = FloatMatrix::from_vec(vec![1.0, 2.0, f64::NAN, 4.0], 4, 1); // 4 rows, 1 col - // Data: [1.0, 2.0, NaN, 4.0] + // Data: [1.0, 2.0, NaN, 4.0] // Vertical (sums/prods/counts for the single column) // Col 0: 1.0 + 2.0 + NaN + 4.0 = 7.0