mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:19:59 +00:00
Compare commits
2 Commits
727ec91a8b
...
4adcfd0ccb
Author | SHA1 | Date | |
---|---|---|---|
![]() |
4adcfd0ccb | ||
![]() |
bda9b84987 |
@ -14,14 +14,22 @@ impl KMeans {
|
|||||||
let n = x.cols();
|
let n = x.cols();
|
||||||
assert!(k <= m, "k must be ≤ number of samples");
|
assert!(k <= m, "k must be ≤ number of samples");
|
||||||
|
|
||||||
// ----- initialise centroids: pick k distinct rows at random -----
|
// ----- initialise centroids -----
|
||||||
let mut rng = rng(); // Changed from thread_rng()
|
|
||||||
let mut indices: Vec<usize> = (0..m).collect();
|
|
||||||
indices.shuffle(&mut rng);
|
|
||||||
let mut centroids = Matrix::zeros(k, n);
|
let mut centroids = Matrix::zeros(k, n);
|
||||||
for (c, &i) in indices[..k].iter().enumerate() {
|
if k == 1 {
|
||||||
|
// For k=1, initialize the centroid to the mean of the data
|
||||||
for j in 0..n {
|
for j in 0..n {
|
||||||
centroids[(c, j)] = x[(i, j)];
|
centroids[(0, j)] = x.column(j).iter().sum::<f64>() / m as f64;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For k > 1, pick k distinct rows at random
|
||||||
|
let mut rng = rng(); // Changed from thread_rng()
|
||||||
|
let mut indices: Vec<usize> = (0..m).collect();
|
||||||
|
indices.shuffle(&mut rng);
|
||||||
|
for (c, &i) in indices[..k].iter().enumerate() {
|
||||||
|
for j in 0..n {
|
||||||
|
centroids[(c, j)] = x[(i, j)];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user