From 19bc09fd5afdbec3031369b7a2d2bcdf856ae37a Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Sun, 13 Jul 2025 01:29:19 +0100 Subject: [PATCH] Refactor KMeans centroid initialization and improve handling of edge cases --- src/compute/models/k_means.rs | 129 ++++++++++++++++++++-------------- 1 file changed, 75 insertions(+), 54 deletions(-) diff --git a/src/compute/models/k_means.rs b/src/compute/models/k_means.rs index 7e66014..6dcd155 100644 --- a/src/compute/models/k_means.rs +++ b/src/compute/models/k_means.rs @@ -1,6 +1,6 @@ +use crate::compute::stats::mean_vertical; use crate::matrix::Matrix; -use crate::matrix::{FloatMatrix, SeriesOps}; -use rand::rng; // Changed from rand::thread_rng +use rand::rng; use rand::seq::SliceRandom; pub struct KMeans { @@ -16,50 +16,50 @@ impl KMeans { // ----- initialise centroids ----- let mut centroids = Matrix::zeros(k, n); - if k == 1 { - // For k=1, initialize the centroid to the mean of the data - for j in 0..n { - centroids[(0, j)] = x.column(j).iter().sum::<f64>() / m as f64; - } - } else { - // For k > 1, pick k distinct rows at random - let mut rng = rng(); // Changed from thread_rng() - let mut indices: Vec<usize> = (0..m).collect(); - indices.shuffle(&mut rng); - for (c, &i) in indices[..k].iter().enumerate() { - for j in 0..n { - centroids[(c, j)] = x[(i, j)]; + if k > 0 && m > 0 { + // case for empty data + if k == 1 { + let mean = mean_vertical(x); + centroids.row_copy_from_slice(0, &mean.data()); // ideally, data.row(0), but thats the same + } else { + // For k > 1, pick k distinct rows at random + let mut rng = rng(); + let mut indices: Vec<usize> = (0..m).collect(); + indices.shuffle(&mut rng); + for c in 0..k { + centroids.row_copy_from_slice(c, &x.row(indices[c])); } } } let mut labels = vec![0usize; m]; - let mut old_centroids = centroids.clone(); // Store initial centroids for first iteration's convergence check + let mut distances = vec![0.0f64; m]; for _iter in 0..max_iter { - 
// Renamed loop variable to _iter for clarity - // ----- assignment step ----- let mut changed = false; + // ----- assignment step ----- for i in 0..m { let sample_row = x.row(i); - let sample_matrix = FloatMatrix::from_rows_vec(sample_row, 1, n); - let mut best = 0usize; let mut best_dist_sq = f64::MAX; for c in 0..k { - let centroid_row = old_centroids.row(c); // Use old_centroids for distance calculation - let centroid_matrix = FloatMatrix::from_rows_vec(centroid_row, 1, n); + let centroid_row = centroids.row(c); - let diff = &sample_matrix - &centroid_matrix; - let sq_diff = &diff * &diff; - let dist_sq = sq_diff.sum_horizontal()[0]; + let dist_sq: f64 = sample_row + .iter() + .zip(centroid_row.iter()) + .map(|(a, b)| (a - b).powi(2)) + .sum(); if dist_sq < best_dist_sq { best_dist_sq = dist_sq; best = c; } } + + distances[i] = best_dist_sq; + if labels[i] != best { labels[i] = best; changed = true; @@ -67,8 +67,8 @@ impl KMeans { } // ----- update step ----- + let mut new_centroids = Matrix::zeros(k, n); let mut counts = vec![0usize; k]; - let mut new_centroids = Matrix::zeros(k, n); // New centroids for this iteration for i in 0..m { let c = labels[i]; counts[c] += 1; @@ -76,8 +76,29 @@ impl KMeans { new_centroids[(c, j)] += x[(i, j)]; } } + for c in 0..k { - if counts[c] > 0 { + if counts[c] == 0 { + // This cluster is empty. Re-initialize its centroid to the point + // furthest from its assigned centroid to prevent the cluster from dying. + let mut furthest_point_idx = 0; + let mut max_dist_sq = 0.0; + for (i, &dist) in distances.iter().enumerate() { + if dist > max_dist_sq { + max_dist_sq = dist; + furthest_point_idx = i; + } + } + + for j in 0..n { + new_centroids[(c, j)] = x[(furthest_point_idx, j)]; + } + // Ensure this point isn't chosen again for another empty cluster in the same iteration. + if m > 0 { + distances[furthest_point_idx] = 0.0; + } + } else { + // Normalize the centroid by the number of points in it. 
for j in 0..n { new_centroids[(c, j)] /= counts[c] as f64; } @@ -86,53 +107,47 @@ impl KMeans { // ----- convergence test ----- if !changed { + centroids = new_centroids; // update before breaking break; // assignments stable } - if tol > 0.0 { - // optional centroid-shift tolerance - let diff = &new_centroids - &old_centroids; // Calculate difference between new and old centroids - let sq_diff = &diff * &diff; - let shift = sq_diff.data().iter().sum::<f64>().sqrt(); // Sum all squared differences + let diff = &new_centroids - &centroids; + centroids = new_centroids; // Update for the next iteration + + if tol > 0.0 { + let sq_diff = &diff * &diff; + let shift = sq_diff.data().iter().sum::<f64>().sqrt(); if shift < tol { break; } } - old_centroids = new_centroids; // Update old_centroids for next iteration } - ( - Self { - centroids: old_centroids, - }, - labels, - ) // Return the final centroids + (Self { centroids }, labels) } /// Predict nearest centroid for each sample. pub fn predict(&self, x: &Matrix<f64>) -> Vec<usize> { let m = x.rows(); let k = self.centroids.rows(); - let n = x.cols(); if m == 0 { - // Handle empty input matrix return Vec::new(); } let mut labels = vec![0usize; m]; for i in 0..m { let sample_row = x.row(i); - let sample_matrix = FloatMatrix::from_rows_vec(sample_row, 1, n); - let mut best = 0usize; let mut best_dist_sq = f64::MAX; + for c in 0..k { let centroid_row = self.centroids.row(c); - let centroid_matrix = FloatMatrix::from_rows_vec(centroid_row, 1, n); - let diff = &sample_matrix - &centroid_matrix; - let sq_diff = &diff * &diff; - let dist_sq = sq_diff.sum_horizontal()[0]; + let dist_sq: f64 = sample_row + .iter() + .zip(centroid_row.iter()) + .map(|(a, b)| (a - b).powi(2)) + .sum(); if dist_sq < best_dist_sq { best_dist_sq = dist_sq; @@ -236,10 +251,16 @@ mod tests { assert_eq!(kmeans_model.centroids.rows(), k); assert_eq!(labels.len(), x.rows()); - // Each sample should be its own cluster, so labels should be unique + // Each sample should be its own 
cluster. Due to random init, labels + // might not be [0,1,2,3,4] but will be a permutation of it. let mut sorted_labels = labels.clone(); sorted_labels.sort_unstable(); - assert_eq!(sorted_labels, vec![0, 1, 2, 3, 4]); + sorted_labels.dedup(); + assert_eq!( + sorted_labels.len(), + k, + "Labels should all be unique when k==m" + ); } #[test] @@ -259,7 +280,7 @@ mod tests { let x = create_simple_integer_data(); // Use integer data let k = 1; let max_iter = 100; - let tol = 1e-6; // Reset tolerance + let tol = 1e-6; let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol); @@ -273,9 +294,8 @@ mod tests { let expected_centroid_x = x.column(0).iter().sum::<f64>() / x.rows() as f64; let expected_centroid_y = x.column(1).iter().sum::<f64>() / x.rows() as f64; - // Relax the assertion tolerance to match the algorithm's convergence tolerance - assert!((kmeans_model.centroids[(0, 0)] - expected_centroid_x).abs() < 1e-6); - assert!((kmeans_model.centroids[(0, 1)] - expected_centroid_y).abs() < 1e-6); + assert!((kmeans_model.centroids[(0, 0)] - expected_centroid_x).abs() < 1e-9); + assert!((kmeans_model.centroids[(0, 1)] - expected_centroid_y).abs() < 1e-9); } #[test] @@ -285,7 +305,8 @@ mod tests { let tol = 1e-6; let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol); - // Create a 0x0 matrix. This is allowed by Matrix constructor. + // The `Matrix` type not support 0xN or Nx0 matrices. + // test with a 0x0 matrix is a valid edge case. let empty_x = FloatMatrix::from_rows_vec(vec![], 0, 0); let predicted_labels = kmeans_model.predict(&empty_x); assert!(predicted_labels.is_empty());