diff --git a/src/compute/models/k_means.rs b/src/compute/models/k_means.rs index a92e414..7e66014 100644 --- a/src/compute/models/k_means.rs +++ b/src/compute/models/k_means.rs @@ -14,14 +14,22 @@ impl KMeans { let n = x.cols(); assert!(k <= m, "k must be ≤ number of samples"); - // ----- initialise centroids: pick k distinct rows at random ----- - let mut rng = rng(); // Changed from thread_rng() - let mut indices: Vec = (0..m).collect(); - indices.shuffle(&mut rng); + // ----- initialise centroids ----- let mut centroids = Matrix::zeros(k, n); - for (c, &i) in indices[..k].iter().enumerate() { + if k == 1 { + // For k=1, initialize the centroid to the mean of the data for j in 0..n { - centroids[(c, j)] = x[(i, j)]; + centroids[(0, j)] = x.column(j).iter().sum::() / m as f64; + } + } else { + // For k > 1, pick k distinct rows at random + let mut rng = rng(); // Changed from thread_rng() + let mut indices: Vec = (0..m).collect(); + indices.shuffle(&mut rng); + for (c, &i) in indices[..k].iter().enumerate() { + for j in 0..n { + centroids[(c, j)] = x[(i, j)]; + } } }