From bda9b849877985f15ae082b95e60a237a344d00c Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Sun, 13 Jul 2025 00:16:29 +0100 Subject: [PATCH] Refactor KMeans centroid initialization to handle k=1 case by setting centroid to mean of data --- src/compute/models/k_means.rs | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/compute/models/k_means.rs b/src/compute/models/k_means.rs index a92e414..7e66014 100644 --- a/src/compute/models/k_means.rs +++ b/src/compute/models/k_means.rs @@ -14,14 +14,22 @@ impl KMeans { let n = x.cols(); assert!(k <= m, "k must be ≤ number of samples"); - // ----- initialise centroids: pick k distinct rows at random ----- - let mut rng = rng(); // Changed from thread_rng() - let mut indices: Vec<usize> = (0..m).collect(); - indices.shuffle(&mut rng); + // ----- initialise centroids ----- let mut centroids = Matrix::zeros(k, n); - for (c, &i) in indices[..k].iter().enumerate() { + if k == 1 { + // For k=1, initialize the centroid to the mean of the data for j in 0..n { - centroids[(c, j)] = x[(i, j)]; + centroids[(0, j)] = x.column(j).iter().sum::<f64>() / m as f64; + } + } else { + // For k > 1, pick k distinct rows at random + let mut rng = rng(); // Changed from thread_rng() + let mut indices: Vec<usize> = (0..m).collect(); + indices.shuffle(&mut rng); + for (c, &i) in indices[..k].iter().enumerate() { + for j in 0..n { + centroids[(c, j)] = x[(i, j)]; + } } }