Compare commits

...

2 Commits

View File

@ -14,14 +14,22 @@ impl KMeans {
let n = x.cols();
assert!(k <= m, "k must be ≤ number of samples");
// ----- initialise centroids: pick k distinct rows at random -----
let mut rng = rng(); // Changed from thread_rng()
let mut indices: Vec<usize> = (0..m).collect();
indices.shuffle(&mut rng);
// ----- initialise centroids -----
let mut centroids = Matrix::zeros(k, n);
for (c, &i) in indices[..k].iter().enumerate() {
if k == 1 {
// For k=1, initialize the centroid to the mean of the data
for j in 0..n {
centroids[(c, j)] = x[(i, j)];
centroids[(0, j)] = x.column(j).iter().sum::<f64>() / m as f64;
}
} else {
// For k > 1, pick k distinct rows at random
let mut rng = rng(); // Changed from thread_rng()
let mut indices: Vec<usize> = (0..m).collect();
indices.shuffle(&mut rng);
for (c, &i) in indices[..k].iter().enumerate() {
for j in 0..n {
centroids[(c, j)] = x[(i, j)];
}
}
}