diff --git a/src/compute/models/k_means.rs b/src/compute/models/k_means.rs index b76fe51..83412e3 100644 --- a/src/compute/models/k_means.rs +++ b/src/compute/models/k_means.rs @@ -162,6 +162,45 @@ impl KMeans { #[cfg(test)] mod tests { + #[test] + fn test_k_means_empty_cluster_reinit_centroid() { + // Try multiple times to increase the chance of hitting the empty cluster case + for _ in 0..20 { + let data = vec![0.0, 0.0, 0.0, 0.0, 10.0, 10.0]; + let x = FloatMatrix::from_rows_vec(data, 3, 2); + let k = 2; + let max_iter = 10; + let tol = 1e-6; + + let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol); + + // Check if any cluster is empty + let mut counts = vec![0; k]; + for &label in &labels { + counts[label] += 1; + } + if counts.iter().any(|&c| c == 0) { + // Only check the property for clusters that are empty + let centroids = kmeans_model.centroids; + for c in 0..k { + if counts[c] == 0 { + let mut matches_data_point = false; + for i in 0..3 { + let dx = centroids[(c, 0)] - x[(i, 0)]; + let dy = centroids[(c, 1)] - x[(i, 1)]; + if dx.abs() < 1e-9 && dy.abs() < 1e-9 { + matches_data_point = true; + break; + } + } + assert!(matches_data_point, "Centroid {} (empty cluster) does not match any data point", c); + } + } + break; + } + } + // If we never saw an empty cluster, that's fine; the test passes as long as no panic occurred + } use super::*; use crate::matrix::FloatMatrix;