mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Implement matrix multiplication and dot product methods for FloatMatrix in SeriesOps trait
This commit is contained in:
parent
3e279b8138
commit
75f194b8c9
@ -12,6 +12,9 @@ pub trait SeriesOps {
|
|||||||
where
|
where
|
||||||
F: FnMut(&[f64]) -> U;
|
F: FnMut(&[f64]) -> U;
|
||||||
|
|
||||||
|
fn matrix_mul(&self, other: &Self) -> FloatMatrix;
|
||||||
|
fn dot(&self, other: &Self) -> FloatMatrix;
|
||||||
|
|
||||||
fn sum_vertical(&self) -> Vec<f64>;
|
fn sum_vertical(&self) -> Vec<f64>;
|
||||||
fn sum_horizontal(&self) -> Vec<f64>;
|
fn sum_horizontal(&self) -> Vec<f64>;
|
||||||
|
|
||||||
@ -139,11 +142,37 @@ impl SeriesOps for FloatMatrix {
|
|||||||
let data = self.data().iter().map(|v| v.is_nan()).collect();
|
let data = self.data().iter().map(|v| v.is_nan()).collect();
|
||||||
BoolMatrix::from_vec(data, self.rows(), self.cols())
|
BoolMatrix::from_vec(data, self.rows(), self.cols())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
fn matrix_mul(&self, other: &Self) -> FloatMatrix {
|
||||||
|
let (m, n) = (self.rows(), self.cols());
|
||||||
|
let (n2, p) = (other.rows(), other.cols());
|
||||||
|
assert_eq!(
|
||||||
|
n, n2,
|
||||||
|
"Cannot multiply: left is {}x{}, right is {}x{}",
|
||||||
|
m, n, n2, p
|
||||||
|
);
|
||||||
|
|
||||||
|
// Column-major addressing: element (row i, col j) lives at j * m + i
|
||||||
|
let mut data = vec![0.0; m * p];
|
||||||
|
for i in 0..m {
|
||||||
|
for j in 0..p {
|
||||||
|
let mut sum = 0.0;
|
||||||
|
for k in 0..n {
|
||||||
|
sum += self[(i, k)] * other[(k, j)];
|
||||||
|
}
|
||||||
|
data[j * m + i] = sum; // <-- fixed index
|
||||||
|
}
|
||||||
|
}
|
||||||
|
FloatMatrix::from_vec(data, m, p)
|
||||||
|
}
|
||||||
|
fn dot(&self, other: &Self) -> FloatMatrix {
|
||||||
|
self.matrix_mul(other)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
// Helper function to create a FloatMatrix for SeriesOps testing
|
// Helper function to create a FloatMatrix for SeriesOps testing
|
||||||
@ -256,6 +285,38 @@ mod tests {
|
|||||||
assert_eq!(matrix.is_nan(), expected_matrix);
|
assert_eq!(matrix.is_nan(), expected_matrix);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_series_ops_matrix_mul() {
|
||||||
|
let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix
|
||||||
|
let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix
|
||||||
|
// result should be: 23, 34, 31, 46
|
||||||
|
let expected = FloatMatrix::from_vec(vec![23.0, 34.0, 31.0, 46.0], 2, 2);
|
||||||
|
assert_eq!(a.matrix_mul(&b), expected);
|
||||||
|
|
||||||
|
assert_eq!(a.dot(&b), a.matrix_mul(&b)); // dot should be the same as matrix_mul for FloatMatrix
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn test_series_ops_matrix_mul_with_nans() {
|
||||||
|
let a = create_float_test_matrix(); // 3x3 matrix with some NaNs
|
||||||
|
let b = create_float_test_matrix(); // 3x3 matrix with some NaNs
|
||||||
|
|
||||||
|
let mut result_vec = Vec::new();
|
||||||
|
result_vec.push(30.0);
|
||||||
|
for _ in 1..9 {
|
||||||
|
result_vec.push(f64::NAN);
|
||||||
|
}
|
||||||
|
let expected = FloatMatrix::from_vec(result_vec, 3, 3);
|
||||||
|
|
||||||
|
let result = a.matrix_mul(&b);
|
||||||
|
|
||||||
|
assert_eq!(result.is_nan(), expected.is_nan());
|
||||||
|
assert_eq!(
|
||||||
|
result.count_nan_horizontal(),
|
||||||
|
expected.count_nan_horizontal()
|
||||||
|
);
|
||||||
|
assert_eq!(result.count_nan_vertical(), expected.count_nan_vertical());
|
||||||
|
assert_eq!(result[(0, 0)], expected[(0, 0)]);
|
||||||
|
}
|
||||||
// --- Edge Cases for SeriesOps ---
|
// --- Edge Cases for SeriesOps ---
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -286,7 +347,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_series_ops_1xn_matrix() {
|
fn test_series_ops_1xn_matrix() {
|
||||||
let matrix = FloatMatrix::from_vec(vec![1.0, f64::NAN, 3.0, 4.0], 1, 4); // 1 row, 4 cols
|
let matrix = FloatMatrix::from_vec(vec![1.0, f64::NAN, 3.0, 4.0], 1, 4); // 1 row, 4 cols
|
||||||
// Data: [1.0, NaN, 3.0, 4.0]
|
// Data: [1.0, NaN, 3.0, 4.0]
|
||||||
|
|
||||||
// Vertical (sums/prods/counts per column - each col is just one element)
|
// Vertical (sums/prods/counts per column - each col is just one element)
|
||||||
assert_eq!(matrix.sum_vertical(), vec![1.0, 0.0, 3.0, 4.0]); // NaN sum is 0
|
assert_eq!(matrix.sum_vertical(), vec![1.0, 0.0, 3.0, 4.0]); // NaN sum is 0
|
||||||
@ -316,7 +377,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_series_ops_nx1_matrix() {
|
fn test_series_ops_nx1_matrix() {
|
||||||
let matrix = FloatMatrix::from_vec(vec![1.0, 2.0, f64::NAN, 4.0], 4, 1); // 4 rows, 1 col
|
let matrix = FloatMatrix::from_vec(vec![1.0, 2.0, f64::NAN, 4.0], 4, 1); // 4 rows, 1 col
|
||||||
// Data: [1.0, 2.0, NaN, 4.0]
|
// Data: [1.0, 2.0, NaN, 4.0]
|
||||||
|
|
||||||
// Vertical (sums/prods/counts for the single column)
|
// Vertical (sums/prods/counts for the single column)
|
||||||
// Col 0: 1.0 + 2.0 + NaN + 4.0 = 7.0
|
// Col 0: 1.0 + 2.0 + NaN + 4.0 = 7.0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user