From 10018f7efe8bc0a3b51344036e81158e4c5ba8d5 Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Sat, 12 Jul 2025 00:50:14 +0100 Subject: [PATCH] Refactor covariance_matrix to improve mean calculation and add broadcasting for centered data; add tests for vertical and horizontal covariance matrices --- src/compute/stats/correlation.rs | 94 +++++++++++++++++++++++++------- 1 file changed, 73 insertions(+), 21 deletions(-) diff --git a/src/compute/stats/correlation.rs b/src/compute/stats/correlation.rs index 2d24632..0e29764 100644 --- a/src/compute/stats/correlation.rs +++ b/src/compute/stats/correlation.rs @@ -82,16 +82,30 @@ pub fn covariance_horizontal(x: &Matrix) -> Matrix { /// Calculates the covariance matrix of the input data. /// Assumes input `x` is (n_samples, n_features). pub fn covariance_matrix(x: &Matrix, axis: Axis) -> Matrix { - let (n_samples, _n_features) = x.shape(); + let (n_samples, n_features) = x.shape(); - let mean_matrix = match axis { - Axis::Col => mean_vertical(x), // Mean of each feature (column) - Axis::Row => mean_horizontal(x), // Mean of each sample (row) + let centered_data = match axis { + Axis::Col => { + let mean_matrix = mean_vertical(x); // 1 x n_features + x.zip( + &mean_matrix.broadcast_row_to_target_shape(n_samples, n_features), + |val, m| val - m, + ) + } + Axis::Row => { + let mean_matrix = mean_horizontal(x); // n_samples x 1 + // Manually create a matrix by broadcasting the column vector across columns + let mut broadcasted_mean = Matrix::zeros(n_samples, n_features); + for r in 0..n_samples { + let mean_val = mean_matrix.get(r, 0); + for c in 0..n_features { + *broadcasted_mean.get_mut(r, c) = *mean_val; + } + } + x.zip(&broadcasted_mean, |val, m| val - m) + } }; - // Center the data - let centered_data = x.zip(&mean_matrix.broadcast_row_to_target_shape(n_samples, x.cols()), |val, m| val - m); - // Calculate covariance matrix: (X_centered^T * X_centered) / (n_samples - 1) // If x is (n_samples, n_features), then centered_data is (n_samples, n_features) // centered_data.transpose() is (n_features, n_samples) @@ -148,13 +162,7 @@ mod tests { // Expect 2x2 matrix of all 1.0 for i in 0..2 { for j in 0..2 { - assert!( - (cov_mat.get(i, j) - 1.0).abs() < EPS, - "cov_mat[{},{}] = {}", - i, - j, - cov_mat.get(i, j) - ); + assert!((cov_mat.get(i, j) - 1.0).abs() < EPS); } } } @@ -171,14 +179,58 @@ mod tests { // Expect 2x2 matrix of all 0.25 for i in 0..2 { for j in 0..2 { - assert!( - (cov_mat.get(i, j) - 0.25).abs() < EPS, - "cov_mat[{},{}] = {}", - i, - j, - cov_mat.get(i, j) - ); + assert!((cov_mat.get(i, j) - 0.25).abs() < EPS); } } } + + #[test] + fn test_covariance_matrix_vertical() { + // Test with a simple 2x2 matrix + // M = + // 1, 2 + // 3, 4 + // Expected covariance matrix (vertical, i.e., between columns): + // Col1: [1, 3], mean = 2 + // Col2: [2, 4], mean = 3 + // Cov(Col1, Col1) = ((1-2)^2 + (3-2)^2) / (2-1) = (1+1)/1 = 2 + // Cov(Col2, Col2) = ((2-3)^2 + (4-3)^2) / (2-1) = (1+1)/1 = 2 + // Cov(Col1, Col2) = ((1-2)*(2-3) + (3-2)*(4-3)) / (2-1) = ((-1)*(-1) + (1)*(1))/1 = (1+1)/1 = 2 + // Cov(Col2, Col1) = 2 + // Expected: + // 2, 2 + // 2, 2 + let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); + let cov_mat = covariance_matrix(&m, Axis::Col); + + assert!((cov_mat.get(0, 0) - 2.0).abs() < EPS); + assert!((cov_mat.get(0, 1) - 2.0).abs() < EPS); + assert!((cov_mat.get(1, 0) - 2.0).abs() < EPS); + assert!((cov_mat.get(1, 1) - 2.0).abs() < EPS); + } + + #[test] + fn test_covariance_matrix_horizontal() { + // Test with a simple 2x2 matrix + // M = + // 1, 2 + // 3, 4 + // Expected covariance matrix (horizontal, i.e., between rows): + // Row1: [1, 2], mean = 1.5 + // Row2: [3, 4], mean = 3.5 + // Cov(Row1, Row1) = ((1-1.5)^2 + (2-1.5)^2) / (2-1) = (0.25+0.25)/1 = 0.5 + // Cov(Row2, Row2) = ((3-3.5)^2 + (4-3.5)^2) / (2-1) = (0.25+0.25)/1 = 0.5 + // Cov(Row1, Row2) = ((1-1.5)*(3-3.5) + (2-1.5)*(4-3.5)) / (2-1) = ((-0.5)*(-0.5) + (0.5)*(0.5))/1 = (0.25+0.25)/1 = 0.5 + // Cov(Row2, Row1) = 0.5 + // Expected: + // 0.5, -0.5 + // -0.5, 0.5 + let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); + let cov_mat = covariance_matrix(&m, Axis::Row); + + assert!((cov_mat.get(0, 0) - 0.5).abs() < EPS); + assert!((cov_mat.get(0, 1) - (-0.5)).abs() < EPS); + assert!((cov_mat.get(1, 0) - (-0.5)).abs() < EPS); + assert!((cov_mat.get(1, 1) - 0.5).abs() < EPS); + } }