Refactor covariance_matrix to improve mean calculation and add broadcasting for centered data; add tests for vertical and horizontal covariance matrices

This commit is contained in:
Palash Tyagi 2025-07-12 00:50:14 +01:00
parent b7480b20d4
commit 10018f7efe

View File

@ -82,16 +82,30 @@ pub fn covariance_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
/// Calculates the covariance matrix of the input data. /// Calculates the covariance matrix of the input data.
/// Assumes input `x` is (n_samples, n_features). /// Assumes input `x` is (n_samples, n_features).
pub fn covariance_matrix(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> { pub fn covariance_matrix(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
let (n_samples, _n_features) = x.shape(); let (n_samples, n_features) = x.shape();
let mean_matrix = match axis { let centered_data = match axis {
Axis::Col => mean_vertical(x), // Mean of each feature (column) Axis::Col => {
Axis::Row => mean_horizontal(x), // Mean of each sample (row) let mean_matrix = mean_vertical(x); // 1 x n_features
x.zip(
&mean_matrix.broadcast_row_to_target_shape(n_samples, n_features),
|val, m| val - m,
)
}
Axis::Row => {
let mean_matrix = mean_horizontal(x); // n_samples x 1
// Manually create a matrix by broadcasting the column vector across columns
let mut broadcasted_mean = Matrix::zeros(n_samples, n_features);
for r in 0..n_samples {
let mean_val = mean_matrix.get(r, 0);
for c in 0..n_features {
*broadcasted_mean.get_mut(r, c) = *mean_val;
}
}
x.zip(&broadcasted_mean, |val, m| val - m)
}
}; };
// Center the data
let centered_data = x.zip(&mean_matrix.broadcast_row_to_target_shape(n_samples, x.cols()), |val, m| val - m);
// Calculate covariance matrix: (X_centered^T * X_centered) / (n_samples - 1) // Calculate covariance matrix: (X_centered^T * X_centered) / (n_samples - 1)
// If x is (n_samples, n_features), then centered_data is (n_samples, n_features) // If x is (n_samples, n_features), then centered_data is (n_samples, n_features)
// centered_data.transpose() is (n_features, n_samples) // centered_data.transpose() is (n_features, n_samples)
@ -148,13 +162,7 @@ mod tests {
// Expect 2x2 matrix of all 1.0 // Expect 2x2 matrix of all 1.0
for i in 0..2 { for i in 0..2 {
for j in 0..2 { for j in 0..2 {
assert!( assert!((cov_mat.get(i, j) - 1.0).abs() < EPS);
(cov_mat.get(i, j) - 1.0).abs() < EPS,
"cov_mat[{},{}] = {}",
i,
j,
cov_mat.get(i, j)
);
} }
} }
} }
@ -171,14 +179,58 @@ mod tests {
// Expect 2x2 matrix of all 0.25 // Expect 2x2 matrix of all 0.25
for i in 0..2 { for i in 0..2 {
for j in 0..2 { for j in 0..2 {
assert!( assert!((cov_mat.get(i, j) - 0.25).abs() < EPS);
(cov_mat.get(i, j) - 0.25).abs() < EPS,
"cov_mat[{},{}] = {}",
i,
j,
cov_mat.get(i, j)
);
} }
} }
} }
#[test]
fn test_covariance_matrix_vertical() {
// Test with a simple 2x2 matrix
// M =
// 1, 2
// 3, 4
// Expected covariance matrix (vertical, i.e., between columns):
// Col1: [1, 3], mean = 2
// Col2: [2, 4], mean = 3
// Cov(Col1, Col1) = ((1-2)^2 + (3-2)^2) / (2-1) = (1+1)/1 = 2
// Cov(Col2, Col2) = ((2-3)^2 + (4-3)^2) / (2-1) = (1+1)/1 = 2
// Cov(Col1, Col2) = ((1-2)*(2-3) + (3-2)*(4-3)) / (2-1) = ((-1)*(-1) + (1)*(1))/1 = (1+1)/1 = 2
// Cov(Col2, Col1) = 2
// Expected:
// 2, 2
// 2, 2
let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
let cov_mat = covariance_matrix(&m, Axis::Col);
assert!((cov_mat.get(0, 0) - 2.0).abs() < EPS);
assert!((cov_mat.get(0, 1) - 2.0).abs() < EPS);
assert!((cov_mat.get(1, 0) - 2.0).abs() < EPS);
assert!((cov_mat.get(1, 1) - 2.0).abs() < EPS);
}
#[test]
fn test_covariance_matrix_horizontal() {
// Test with a simple 2x2 matrix
// M =
// 1, 2
// 3, 4
// Expected covariance matrix (horizontal, i.e., between rows):
// Row1: [1, 2], mean = 1.5
// Row2: [3, 4], mean = 3.5
// Cov(Row1, Row1) = ((1-1.5)^2 + (2-1.5)^2) / (2-1) = (0.25+0.25)/1 = 0.5
// Cov(Row2, Row2) = ((3-3.5)^2 + (4-3.5)^2) / (2-1) = (0.25+0.25)/1 = 0.5
// Cov(Row1, Row2) = ((1-1.5)*(3-3.5) + (2-1.5)*(4-3.5)) / (2-1) = ((-0.5)*(-0.5) + (0.5)*(0.5))/1 = (0.25+0.25)/1 = 0.5
// Cov(Row2, Row1) = 0.5
// Expected:
// 0.5, -0.5
// -0.5, 0.5
let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
let cov_mat = covariance_matrix(&m, Axis::Row);
assert!((cov_mat.get(0, 0) - 0.5).abs() < EPS);
assert!((cov_mat.get(0, 1) - (-0.5)).abs() < EPS);
assert!((cov_mat.get(1, 0) - (-0.5)).abs() < EPS);
assert!((cov_mat.get(1, 1) - 0.5).abs() < EPS);
}
} }