mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Refactor KMeans centroid initialization and improve handling of edge cases
This commit is contained in:
parent
bda9b84987
commit
19bc09fd5a
@ -1,6 +1,6 @@
|
|||||||
|
use crate::compute::stats::mean_vertical;
|
||||||
use crate::matrix::Matrix;
|
use crate::matrix::Matrix;
|
||||||
use crate::matrix::{FloatMatrix, SeriesOps};
|
use rand::rng;
|
||||||
use rand::rng; // Changed from rand::thread_rng
|
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
|
|
||||||
pub struct KMeans {
|
pub struct KMeans {
|
||||||
@ -16,50 +16,50 @@ impl KMeans {
|
|||||||
|
|
||||||
// ----- initialise centroids -----
|
// ----- initialise centroids -----
|
||||||
let mut centroids = Matrix::zeros(k, n);
|
let mut centroids = Matrix::zeros(k, n);
|
||||||
if k == 1 {
|
if k > 0 && m > 0 {
|
||||||
// For k=1, initialize the centroid to the mean of the data
|
// case for empty data
|
||||||
for j in 0..n {
|
if k == 1 {
|
||||||
centroids[(0, j)] = x.column(j).iter().sum::<f64>() / m as f64;
|
let mean = mean_vertical(x);
|
||||||
}
|
centroids.row_copy_from_slice(0, &mean.data()); // ideally, data.row(0), but thats the same
|
||||||
} else {
|
} else {
|
||||||
// For k > 1, pick k distinct rows at random
|
// For k > 1, pick k distinct rows at random
|
||||||
let mut rng = rng(); // Changed from thread_rng()
|
let mut rng = rng();
|
||||||
let mut indices: Vec<usize> = (0..m).collect();
|
let mut indices: Vec<usize> = (0..m).collect();
|
||||||
indices.shuffle(&mut rng);
|
indices.shuffle(&mut rng);
|
||||||
for (c, &i) in indices[..k].iter().enumerate() {
|
for c in 0..k {
|
||||||
for j in 0..n {
|
centroids.row_copy_from_slice(c, &x.row(indices[c]));
|
||||||
centroids[(c, j)] = x[(i, j)];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut labels = vec![0usize; m];
|
let mut labels = vec![0usize; m];
|
||||||
let mut old_centroids = centroids.clone(); // Store initial centroids for first iteration's convergence check
|
let mut distances = vec![0.0f64; m];
|
||||||
|
|
||||||
for _iter in 0..max_iter {
|
for _iter in 0..max_iter {
|
||||||
// Renamed loop variable to _iter for clarity
|
|
||||||
// ----- assignment step -----
|
|
||||||
let mut changed = false;
|
let mut changed = false;
|
||||||
|
// ----- assignment step -----
|
||||||
for i in 0..m {
|
for i in 0..m {
|
||||||
let sample_row = x.row(i);
|
let sample_row = x.row(i);
|
||||||
let sample_matrix = FloatMatrix::from_rows_vec(sample_row, 1, n);
|
|
||||||
|
|
||||||
let mut best = 0usize;
|
let mut best = 0usize;
|
||||||
let mut best_dist_sq = f64::MAX;
|
let mut best_dist_sq = f64::MAX;
|
||||||
|
|
||||||
for c in 0..k {
|
for c in 0..k {
|
||||||
let centroid_row = old_centroids.row(c); // Use old_centroids for distance calculation
|
let centroid_row = centroids.row(c);
|
||||||
let centroid_matrix = FloatMatrix::from_rows_vec(centroid_row, 1, n);
|
|
||||||
|
|
||||||
let diff = &sample_matrix - ¢roid_matrix;
|
let dist_sq: f64 = sample_row
|
||||||
let sq_diff = &diff * &diff;
|
.iter()
|
||||||
let dist_sq = sq_diff.sum_horizontal()[0];
|
.zip(centroid_row.iter())
|
||||||
|
.map(|(a, b)| (a - b).powi(2))
|
||||||
|
.sum();
|
||||||
|
|
||||||
if dist_sq < best_dist_sq {
|
if dist_sq < best_dist_sq {
|
||||||
best_dist_sq = dist_sq;
|
best_dist_sq = dist_sq;
|
||||||
best = c;
|
best = c;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
distances[i] = best_dist_sq;
|
||||||
|
|
||||||
if labels[i] != best {
|
if labels[i] != best {
|
||||||
labels[i] = best;
|
labels[i] = best;
|
||||||
changed = true;
|
changed = true;
|
||||||
@ -67,8 +67,8 @@ impl KMeans {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ----- update step -----
|
// ----- update step -----
|
||||||
|
let mut new_centroids = Matrix::zeros(k, n);
|
||||||
let mut counts = vec![0usize; k];
|
let mut counts = vec![0usize; k];
|
||||||
let mut new_centroids = Matrix::zeros(k, n); // New centroids for this iteration
|
|
||||||
for i in 0..m {
|
for i in 0..m {
|
||||||
let c = labels[i];
|
let c = labels[i];
|
||||||
counts[c] += 1;
|
counts[c] += 1;
|
||||||
@ -76,8 +76,29 @@ impl KMeans {
|
|||||||
new_centroids[(c, j)] += x[(i, j)];
|
new_centroids[(c, j)] += x[(i, j)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for c in 0..k {
|
for c in 0..k {
|
||||||
if counts[c] > 0 {
|
if counts[c] == 0 {
|
||||||
|
// This cluster is empty. Re-initialize its centroid to the point
|
||||||
|
// furthest from its assigned centroid to prevent the cluster from dying.
|
||||||
|
let mut furthest_point_idx = 0;
|
||||||
|
let mut max_dist_sq = 0.0;
|
||||||
|
for (i, &dist) in distances.iter().enumerate() {
|
||||||
|
if dist > max_dist_sq {
|
||||||
|
max_dist_sq = dist;
|
||||||
|
furthest_point_idx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for j in 0..n {
|
||||||
|
new_centroids[(c, j)] = x[(furthest_point_idx, j)];
|
||||||
|
}
|
||||||
|
// Ensure this point isn't chosen again for another empty cluster in the same iteration.
|
||||||
|
if m > 0 {
|
||||||
|
distances[furthest_point_idx] = 0.0;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Normalize the centroid by the number of points in it.
|
||||||
for j in 0..n {
|
for j in 0..n {
|
||||||
new_centroids[(c, j)] /= counts[c] as f64;
|
new_centroids[(c, j)] /= counts[c] as f64;
|
||||||
}
|
}
|
||||||
@ -86,53 +107,47 @@ impl KMeans {
|
|||||||
|
|
||||||
// ----- convergence test -----
|
// ----- convergence test -----
|
||||||
if !changed {
|
if !changed {
|
||||||
|
centroids = new_centroids; // update before breaking
|
||||||
break; // assignments stable
|
break; // assignments stable
|
||||||
}
|
}
|
||||||
if tol > 0.0 {
|
|
||||||
// optional centroid-shift tolerance
|
|
||||||
let diff = &new_centroids - &old_centroids; // Calculate difference between new and old centroids
|
|
||||||
let sq_diff = &diff * &diff;
|
|
||||||
let shift = sq_diff.data().iter().sum::<f64>().sqrt(); // Sum all squared differences
|
|
||||||
|
|
||||||
|
let diff = &new_centroids - ¢roids;
|
||||||
|
centroids = new_centroids; // Update for the next iteration
|
||||||
|
|
||||||
|
if tol > 0.0 {
|
||||||
|
let sq_diff = &diff * &diff;
|
||||||
|
let shift = sq_diff.data().iter().sum::<f64>().sqrt();
|
||||||
if shift < tol {
|
if shift < tol {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
old_centroids = new_centroids; // Update old_centroids for next iteration
|
|
||||||
}
|
}
|
||||||
(
|
(Self { centroids }, labels)
|
||||||
Self {
|
|
||||||
centroids: old_centroids,
|
|
||||||
},
|
|
||||||
labels,
|
|
||||||
) // Return the final centroids
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict nearest centroid for each sample.
|
/// Predict nearest centroid for each sample.
|
||||||
pub fn predict(&self, x: &Matrix<f64>) -> Vec<usize> {
|
pub fn predict(&self, x: &Matrix<f64>) -> Vec<usize> {
|
||||||
let m = x.rows();
|
let m = x.rows();
|
||||||
let k = self.centroids.rows();
|
let k = self.centroids.rows();
|
||||||
let n = x.cols();
|
|
||||||
|
|
||||||
if m == 0 {
|
if m == 0 {
|
||||||
// Handle empty input matrix
|
|
||||||
return Vec::new();
|
return Vec::new();
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut labels = vec![0usize; m];
|
let mut labels = vec![0usize; m];
|
||||||
for i in 0..m {
|
for i in 0..m {
|
||||||
let sample_row = x.row(i);
|
let sample_row = x.row(i);
|
||||||
let sample_matrix = FloatMatrix::from_rows_vec(sample_row, 1, n);
|
|
||||||
|
|
||||||
let mut best = 0usize;
|
let mut best = 0usize;
|
||||||
let mut best_dist_sq = f64::MAX;
|
let mut best_dist_sq = f64::MAX;
|
||||||
|
|
||||||
for c in 0..k {
|
for c in 0..k {
|
||||||
let centroid_row = self.centroids.row(c);
|
let centroid_row = self.centroids.row(c);
|
||||||
let centroid_matrix = FloatMatrix::from_rows_vec(centroid_row, 1, n);
|
|
||||||
|
|
||||||
let diff = &sample_matrix - ¢roid_matrix;
|
let dist_sq: f64 = sample_row
|
||||||
let sq_diff = &diff * &diff;
|
.iter()
|
||||||
let dist_sq = sq_diff.sum_horizontal()[0];
|
.zip(centroid_row.iter())
|
||||||
|
.map(|(a, b)| (a - b).powi(2))
|
||||||
|
.sum();
|
||||||
|
|
||||||
if dist_sq < best_dist_sq {
|
if dist_sq < best_dist_sq {
|
||||||
best_dist_sq = dist_sq;
|
best_dist_sq = dist_sq;
|
||||||
@ -236,10 +251,16 @@ mod tests {
|
|||||||
assert_eq!(kmeans_model.centroids.rows(), k);
|
assert_eq!(kmeans_model.centroids.rows(), k);
|
||||||
assert_eq!(labels.len(), x.rows());
|
assert_eq!(labels.len(), x.rows());
|
||||||
|
|
||||||
// Each sample should be its own cluster, so labels should be unique
|
// Each sample should be its own cluster. Due to random init, labels
|
||||||
|
// might not be [0,1,2,3,4] but will be a permutation of it.
|
||||||
let mut sorted_labels = labels.clone();
|
let mut sorted_labels = labels.clone();
|
||||||
sorted_labels.sort_unstable();
|
sorted_labels.sort_unstable();
|
||||||
assert_eq!(sorted_labels, vec![0, 1, 2, 3, 4]);
|
sorted_labels.dedup();
|
||||||
|
assert_eq!(
|
||||||
|
sorted_labels.len(),
|
||||||
|
k,
|
||||||
|
"Labels should all be unique when k==m"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -259,7 +280,7 @@ mod tests {
|
|||||||
let x = create_simple_integer_data(); // Use integer data
|
let x = create_simple_integer_data(); // Use integer data
|
||||||
let k = 1;
|
let k = 1;
|
||||||
let max_iter = 100;
|
let max_iter = 100;
|
||||||
let tol = 1e-6; // Reset tolerance
|
let tol = 1e-6;
|
||||||
|
|
||||||
let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);
|
let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);
|
||||||
|
|
||||||
@ -273,9 +294,8 @@ mod tests {
|
|||||||
let expected_centroid_x = x.column(0).iter().sum::<f64>() / x.rows() as f64;
|
let expected_centroid_x = x.column(0).iter().sum::<f64>() / x.rows() as f64;
|
||||||
let expected_centroid_y = x.column(1).iter().sum::<f64>() / x.rows() as f64;
|
let expected_centroid_y = x.column(1).iter().sum::<f64>() / x.rows() as f64;
|
||||||
|
|
||||||
// Relax the assertion tolerance to match the algorithm's convergence tolerance
|
assert!((kmeans_model.centroids[(0, 0)] - expected_centroid_x).abs() < 1e-9);
|
||||||
assert!((kmeans_model.centroids[(0, 0)] - expected_centroid_x).abs() < 1e-6);
|
assert!((kmeans_model.centroids[(0, 1)] - expected_centroid_y).abs() < 1e-9);
|
||||||
assert!((kmeans_model.centroids[(0, 1)] - expected_centroid_y).abs() < 1e-6);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -285,7 +305,8 @@ mod tests {
|
|||||||
let tol = 1e-6;
|
let tol = 1e-6;
|
||||||
let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
|
let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
|
||||||
|
|
||||||
// Create a 0x0 matrix. This is allowed by Matrix constructor.
|
// The `Matrix` type not support 0xN or Nx0 matrices.
|
||||||
|
// test with a 0x0 matrix is a valid edge case.
|
||||||
let empty_x = FloatMatrix::from_rows_vec(vec![], 0, 0);
|
let empty_x = FloatMatrix::from_rows_vec(vec![], 0, 0);
|
||||||
let predicted_labels = kmeans_model.predict(&empty_x);
|
let predicted_labels = kmeans_model.predict(&empty_x);
|
||||||
assert!(predicted_labels.is_empty());
|
assert!(predicted_labels.is_empty());
|
||||||
|
Loading…
x
Reference in New Issue
Block a user