Implement matrix multiplication and dot product methods for FloatMatrix in SeriesOps trait

This commit is contained in:
Palash Tyagi 2025-07-05 20:44:16 +01:00
parent 3e279b8138
commit 75f194b8c9

View File

@ -12,6 +12,9 @@ pub trait SeriesOps {
where where
F: FnMut(&[f64]) -> U; F: FnMut(&[f64]) -> U;
fn matrix_mul(&self, other: &Self) -> FloatMatrix;
fn dot(&self, other: &Self) -> FloatMatrix;
fn sum_vertical(&self) -> Vec<f64>; fn sum_vertical(&self) -> Vec<f64>;
fn sum_horizontal(&self) -> Vec<f64>; fn sum_horizontal(&self) -> Vec<f64>;
@ -139,11 +142,37 @@ impl SeriesOps for FloatMatrix {
let data = self.data().iter().map(|v| v.is_nan()).collect(); let data = self.data().iter().map(|v| v.is_nan()).collect();
BoolMatrix::from_vec(data, self.rows(), self.cols()) BoolMatrix::from_vec(data, self.rows(), self.cols())
} }
}
fn matrix_mul(&self, other: &Self) -> FloatMatrix {
let (m, n) = (self.rows(), self.cols());
let (n2, p) = (other.rows(), other.cols());
assert_eq!(
n, n2,
"Cannot multiply: left is {}x{}, right is {}x{}",
m, n, n2, p
);
// Column-major addressing: element (row i, col j) lives at j * m + i
let mut data = vec![0.0; m * p];
for i in 0..m {
for j in 0..p {
let mut sum = 0.0;
for k in 0..n {
sum += self[(i, k)] * other[(k, j)];
}
data[j * m + i] = sum; // <-- fixed index
}
}
FloatMatrix::from_vec(data, m, p)
}
fn dot(&self, other: &Self) -> FloatMatrix {
self.matrix_mul(other)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
// Helper function to create a FloatMatrix for SeriesOps testing // Helper function to create a FloatMatrix for SeriesOps testing
@ -256,6 +285,38 @@ mod tests {
assert_eq!(matrix.is_nan(), expected_matrix); assert_eq!(matrix.is_nan(), expected_matrix);
} }
#[test]
fn test_series_ops_matrix_mul() {
let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix
let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix
// result should be: 23, 34, 31, 46
let expected = FloatMatrix::from_vec(vec![23.0, 34.0, 31.0, 46.0], 2, 2);
assert_eq!(a.matrix_mul(&b), expected);
assert_eq!(a.dot(&b), a.matrix_mul(&b)); // dot should be the same as matrix_mul for FloatMatrix
}
#[test]
fn test_series_ops_matrix_mul_with_nans() {
let a = create_float_test_matrix(); // 3x3 matrix with some NaNs
let b = create_float_test_matrix(); // 3x3 matrix with some NaNs
let mut result_vec = Vec::new();
result_vec.push(30.0);
for _ in 1..9 {
result_vec.push(f64::NAN);
}
let expected = FloatMatrix::from_vec(result_vec, 3, 3);
let result = a.matrix_mul(&b);
assert_eq!(result.is_nan(), expected.is_nan());
assert_eq!(
result.count_nan_horizontal(),
expected.count_nan_horizontal()
);
assert_eq!(result.count_nan_vertical(), expected.count_nan_vertical());
assert_eq!(result[(0, 0)], expected[(0, 0)]);
}
// --- Edge Cases for SeriesOps --- // --- Edge Cases for SeriesOps ---
#[test] #[test]
@ -286,7 +347,7 @@ mod tests {
#[test] #[test]
fn test_series_ops_1xn_matrix() { fn test_series_ops_1xn_matrix() {
let matrix = FloatMatrix::from_vec(vec![1.0, f64::NAN, 3.0, 4.0], 1, 4); // 1 row, 4 cols let matrix = FloatMatrix::from_vec(vec![1.0, f64::NAN, 3.0, 4.0], 1, 4); // 1 row, 4 cols
// Data: [1.0, NaN, 3.0, 4.0] // Data: [1.0, NaN, 3.0, 4.0]
// Vertical (sums/prods/counts per column - each col is just one element) // Vertical (sums/prods/counts per column - each col is just one element)
assert_eq!(matrix.sum_vertical(), vec![1.0, 0.0, 3.0, 4.0]); // NaN sum is 0 assert_eq!(matrix.sum_vertical(), vec![1.0, 0.0, 3.0, 4.0]); // NaN sum is 0
@ -316,7 +377,7 @@ mod tests {
#[test] #[test]
fn test_series_ops_nx1_matrix() { fn test_series_ops_nx1_matrix() {
let matrix = FloatMatrix::from_vec(vec![1.0, 2.0, f64::NAN, 4.0], 4, 1); // 4 rows, 1 col let matrix = FloatMatrix::from_vec(vec![1.0, 2.0, f64::NAN, 4.0], 4, 1); // 4 rows, 1 col
// Data: [1.0, 2.0, NaN, 4.0] // Data: [1.0, 2.0, NaN, 4.0]
// Vertical (sums/prods/counts for the single column) // Vertical (sums/prods/counts for the single column)
// Col 0: 1.0 + 2.0 + NaN + 4.0 = 7.0 // Col 0: 1.0 + 2.0 + NaN + 4.0 = 7.0