From 75f194b8c919211e65e3da037393a54b303fb144 Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Sat, 5 Jul 2025 20:44:16 +0100 Subject: [PATCH] Implement matrix multiplication and dot product methods for FloatMatrix in SeriesOps trait --- src/matrix/seriesops.rs | 67 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 3 deletions(-) diff --git a/src/matrix/seriesops.rs b/src/matrix/seriesops.rs index bc34a1f..60e7f30 100644 --- a/src/matrix/seriesops.rs +++ b/src/matrix/seriesops.rs @@ -12,6 +12,9 @@ pub trait SeriesOps { where F: FnMut(&[f64]) -> U; + fn matrix_mul(&self, other: &Self) -> FloatMatrix; + fn dot(&self, other: &Self) -> FloatMatrix; + fn sum_vertical(&self) -> Vec; fn sum_horizontal(&self) -> Vec; @@ -139,11 +142,37 @@ impl SeriesOps for FloatMatrix { let data = self.data().iter().map(|v| v.is_nan()).collect(); BoolMatrix::from_vec(data, self.rows(), self.cols()) } -} + fn matrix_mul(&self, other: &Self) -> FloatMatrix { + let (m, n) = (self.rows(), self.cols()); + let (n2, p) = (other.rows(), other.cols()); + assert_eq!( + n, n2, + "Cannot multiply: left is {}x{}, right is {}x{}", + m, n, n2, p + ); + + // Column-major addressing: element (row i, col j) lives at j * m + i + let mut data = vec![0.0; m * p]; + for i in 0..m { + for j in 0..p { + let mut sum = 0.0; + for k in 0..n { + sum += self[(i, k)] * other[(k, j)]; + } + data[j * m + i] = sum; // <-- fixed index + } + } + FloatMatrix::from_vec(data, m, p) + } + fn dot(&self, other: &Self) -> FloatMatrix { + self.matrix_mul(other) + } +} #[cfg(test)] mod tests { + use super::*; // Helper function to create a FloatMatrix for SeriesOps testing @@ -256,6 +285,38 @@ mod tests { assert_eq!(matrix.is_nan(), expected_matrix); } + #[test] + fn test_series_ops_matrix_mul() { + let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix + let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix + // result should be: 23, 34, 31, 46 + let expected = FloatMatrix::from_vec(vec![23.0, 34.0, 31.0, 46.0], 2, 2); + assert_eq!(a.matrix_mul(&b), expected); + + assert_eq!(a.dot(&b), a.matrix_mul(&b)); // dot should be the same as matrix_mul for FloatMatrix + } + #[test] + fn test_series_ops_matrix_mul_with_nans() { + let a = create_float_test_matrix(); // 3x3 matrix with some NaNs + let b = create_float_test_matrix(); // 3x3 matrix with some NaNs + + let mut result_vec = Vec::new(); + result_vec.push(30.0); + for _ in 1..9 { + result_vec.push(f64::NAN); + } + let expected = FloatMatrix::from_vec(result_vec, 3, 3); + + let result = a.matrix_mul(&b); + + assert_eq!(result.is_nan(), expected.is_nan()); + assert_eq!( + result.count_nan_horizontal(), + expected.count_nan_horizontal() + ); + assert_eq!(result.count_nan_vertical(), expected.count_nan_vertical()); + assert_eq!(result[(0, 0)], expected[(0, 0)]); + } // --- Edge Cases for SeriesOps --- #[test] @@ -286,7 +347,7 @@ mod tests { #[test] fn test_series_ops_1xn_matrix() { let matrix = FloatMatrix::from_vec(vec![1.0, f64::NAN, 3.0, 4.0], 1, 4); // 1 row, 4 cols - // Data: [1.0, NaN, 3.0, 4.0] + // Data: [1.0, NaN, 3.0, 4.0] // Vertical (sums/prods/counts per column - each col is just one element) assert_eq!(matrix.sum_vertical(), vec![1.0, 0.0, 3.0, 4.0]); // NaN sum is 0 @@ -316,7 +377,7 @@ mod tests { #[test] fn test_series_ops_nx1_matrix() { let matrix = FloatMatrix::from_vec(vec![1.0, 2.0, f64::NAN, 4.0], 4, 1); // 4 rows, 1 col - // Data: [1.0, 2.0, NaN, 4.0] + // Data: [1.0, 2.0, NaN, 4.0] // Vertical (sums/prods/counts for the single column) // Col 0: 1.0 + 2.0 + NaN + 4.0 = 7.0