// Mirror of https://github.com/Magnus167/rustframe.git, synced 2025-11-19 17:26:09 +00:00
// (Rust, 375 lines, 13 KiB).
//! Simple k-means clustering working on [`Matrix`] data.
//!
//! [`KMeans::fit`] clusters the rows of a matrix and returns the fitted model together
//! with one label per row; [`KMeans::predict`] assigns new rows to the nearest fitted
//! centroid.
//!
//! ```
//! use rustframe::compute::models::k_means::KMeans;
//! use rustframe::matrix::Matrix;
//!
//! let data = Matrix::from_vec(vec![1.0, 1.0, 5.0, 5.0], 2, 2);
//! let (model, labels) = KMeans::fit(&data, 2, 10, 1e-4);
//! assert_eq!(model.centroids.rows(), 2);
//! assert_eq!(labels.len(), 2);
//! ```
|
|
use crate::compute::stats::mean_vertical;
|
|
use crate::matrix::Matrix;
|
|
use crate::random::prelude::*;
|
|
|
|
pub struct KMeans {
|
|
pub centroids: Matrix<f64>, // (k, n_features)
|
|
}
|
|
|
|
impl KMeans {
|
|
/// Fit with k clusters.
|
|
pub fn fit(x: &Matrix<f64>, k: usize, max_iter: usize, tol: f64) -> (Self, Vec<usize>) {
|
|
let m = x.rows();
|
|
let n = x.cols();
|
|
assert!(k <= m, "k must be ≤ number of samples");
|
|
|
|
// ----- initialise centroids -----
|
|
let mut centroids = Matrix::zeros(k, n);
|
|
if k > 0 && m > 0 {
|
|
// case for empty data
|
|
if k == 1 {
|
|
let mean = mean_vertical(x);
|
|
centroids.row_copy_from_slice(0, &mean.data()); // ideally, data.row(0), but thats the same
|
|
} else {
|
|
// For k > 1, pick k distinct rows at random
|
|
let mut rng = rng();
|
|
let mut indices: Vec<usize> = (0..m).collect();
|
|
indices.shuffle(&mut rng);
|
|
for c in 0..k {
|
|
centroids.row_copy_from_slice(c, &x.row(indices[c]));
|
|
}
|
|
}
|
|
}
|
|
|
|
let mut labels = vec![0usize; m];
|
|
let mut distances = vec![0.0f64; m];
|
|
|
|
for _iter in 0..max_iter {
|
|
let mut changed = false;
|
|
// ----- assignment step -----
|
|
for i in 0..m {
|
|
let sample_row = x.row(i);
|
|
let mut best = 0usize;
|
|
let mut best_dist_sq = f64::MAX;
|
|
|
|
for c in 0..k {
|
|
let centroid_row = centroids.row(c);
|
|
|
|
let dist_sq: f64 = sample_row
|
|
.iter()
|
|
.zip(centroid_row.iter())
|
|
.map(|(a, b)| (a - b).powi(2))
|
|
.sum();
|
|
|
|
if dist_sq < best_dist_sq {
|
|
best_dist_sq = dist_sq;
|
|
best = c;
|
|
}
|
|
}
|
|
|
|
distances[i] = best_dist_sq;
|
|
|
|
if labels[i] != best {
|
|
labels[i] = best;
|
|
changed = true;
|
|
}
|
|
}
|
|
|
|
// ----- update step -----
|
|
let mut new_centroids = Matrix::zeros(k, n);
|
|
let mut counts = vec![0usize; k];
|
|
for i in 0..m {
|
|
let c = labels[i];
|
|
counts[c] += 1;
|
|
for j in 0..n {
|
|
new_centroids[(c, j)] += x[(i, j)];
|
|
}
|
|
}
|
|
|
|
for c in 0..k {
|
|
if counts[c] == 0 {
|
|
// This cluster is empty. Re-initialize its centroid to the point
|
|
// furthest from its assigned centroid to prevent the cluster from dying.
|
|
let mut furthest_point_idx = 0;
|
|
let mut max_dist_sq = 0.0;
|
|
for (i, &dist) in distances.iter().enumerate() {
|
|
if dist > max_dist_sq {
|
|
max_dist_sq = dist;
|
|
furthest_point_idx = i;
|
|
}
|
|
}
|
|
|
|
for j in 0..n {
|
|
new_centroids[(c, j)] = x[(furthest_point_idx, j)];
|
|
}
|
|
// Ensure this point isn't chosen again for another empty cluster in the same iteration.
|
|
if m > 0 {
|
|
distances[furthest_point_idx] = 0.0;
|
|
}
|
|
} else {
|
|
// Normalize the centroid by the number of points in it.
|
|
for j in 0..n {
|
|
new_centroids[(c, j)] /= counts[c] as f64;
|
|
}
|
|
}
|
|
}
|
|
|
|
// ----- convergence test -----
|
|
if !changed {
|
|
centroids = new_centroids; // update before breaking
|
|
break; // assignments stable
|
|
}
|
|
|
|
let diff = &new_centroids - ¢roids;
|
|
centroids = new_centroids; // Update for the next iteration
|
|
|
|
if tol > 0.0 {
|
|
let sq_diff = &diff * &diff;
|
|
let shift = sq_diff.data().iter().sum::<f64>().sqrt();
|
|
if shift < tol {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
(Self { centroids }, labels)
|
|
}
|
|
|
|
/// Predict nearest centroid for each sample.
|
|
pub fn predict(&self, x: &Matrix<f64>) -> Vec<usize> {
|
|
let m = x.rows();
|
|
let k = self.centroids.rows();
|
|
|
|
if m == 0 {
|
|
return Vec::new();
|
|
}
|
|
|
|
let mut labels = vec![0usize; m];
|
|
for i in 0..m {
|
|
let sample_row = x.row(i);
|
|
let mut best = 0usize;
|
|
let mut best_dist_sq = f64::MAX;
|
|
|
|
for c in 0..k {
|
|
let centroid_row = self.centroids.row(c);
|
|
|
|
let dist_sq: f64 = sample_row
|
|
.iter()
|
|
.zip(centroid_row.iter())
|
|
.map(|(a, b)| (a - b).powi(2))
|
|
.sum();
|
|
|
|
if dist_sq < best_dist_sq {
|
|
best_dist_sq = dist_sq;
|
|
best = c;
|
|
}
|
|
}
|
|
labels[i] = best;
|
|
}
|
|
labels
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::FloatMatrix;

    /// Five 2-D samples in two well-separated groups, plus the matching `k`.
    ///
    /// Group A: (1, 1), (1.5, 1.5). Group B: (5, 8), (8, 8), (6, 7).
    fn create_test_data() -> (FloatMatrix, usize) {
        let data = vec![
            1.0, 1.0, // Sample 0
            1.5, 1.5, // Sample 1
            5.0, 8.0, // Sample 2
            8.0, 8.0, // Sample 3
            6.0, 7.0, // Sample 4
        ];
        let x = FloatMatrix::from_rows_vec(data, 5, 2);
        let k = 2;
        (x, k)
    }

    /// Three points (1, 1), (2, 2), (3, 3) whose mean (2, 2) is exactly representable.
    fn create_simple_integer_data() -> FloatMatrix {
        FloatMatrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0], 3, 2)
    }

    #[test]
    fn test_k_means_empty_cluster_reinit_centroid() {
        // Two of the three samples coincide, so the random initialisation can pick the
        // same point twice and leave a cluster empty. Repeat to make that case likely.
        for _ in 0..20 {
            let x = FloatMatrix::from_rows_vec(vec![0.0, 0.0, 0.0, 0.0, 10.0, 10.0], 3, 2);
            let k = 2;
            let max_iter = 10;
            let tol = 1e-6;

            let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);

            let mut counts = vec![0usize; k];
            for &label in &labels {
                counts[label] += 1;
            }
            if counts.iter().all(|&count| count > 0) {
                continue;
            }

            // Every centroid whose cluster ended up empty must sit on a data point.
            let centroids = &kmeans_model.centroids;
            for c in (0..k).filter(|&c| counts[c] == 0) {
                let matches_data_point = (0..x.rows()).any(|i| {
                    (centroids[(c, 0)] - x[(i, 0)]).abs() < 1e-9
                        && (centroids[(c, 1)] - x[(i, 1)]).abs() < 1e-9
                });
                assert!(
                    matches_data_point,
                    "Centroid {} (empty cluster) does not match any data point",
                    c
                );
            }
            break;
        }
        // Never seeing an empty cluster is fine; the test then only checks that no
        // panic occurred.
    }

    #[test]
    fn test_k_means_fit_predict_basic() {
        let (x, k) = create_test_data();
        let max_iter = 100;
        let tol = 1e-6;

        let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);

        // Shapes of the fitted model and labels.
        assert_eq!(kmeans_model.centroids.rows(), k);
        assert_eq!(kmeans_model.centroids.cols(), x.cols());
        assert_eq!(labels.len(), x.rows());
        assert!(labels.iter().all(|&label| label < k));

        // Label ids depend on the random initialisation, but the grouping must not:
        // samples 0-1 share one label and samples 2-4 share the other.
        let group_a = labels[0];
        let group_b = labels[2];
        assert_eq!(labels[1], group_a);
        assert_eq!(labels[3], group_b);
        assert_eq!(labels[4], group_b);
        assert_ne!(group_a, group_b);

        // Predicting on the training data reproduces the fitted labels.
        let predicted_labels = kmeans_model.predict(&x);
        assert_eq!(labels, predicted_labels);

        // New samples land in the nearby group.
        let near_a = FloatMatrix::from_rows_vec(vec![1.2, 1.3], 1, 2);
        assert_eq!(kmeans_model.predict(&near_a)[0], group_a);

        let near_b = FloatMatrix::from_rows_vec(vec![7.0, 7.5], 1, 2);
        assert_eq!(kmeans_model.predict(&near_b)[0], group_b);
    }

    #[test]
    fn test_k_means_fit_k_equals_m() {
        // As many clusters as samples: every sample should get its own cluster.
        let (x, _) = create_test_data(); // 5 samples
        let k = 5;
        let max_iter = 10;
        let tol = 1e-6;

        let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);

        assert_eq!(kmeans_model.centroids.rows(), k);
        assert_eq!(labels.len(), x.rows());

        // Random init permutes the label ids, but they must all be distinct.
        let mut sorted_labels = labels.clone();
        sorted_labels.sort_unstable();
        sorted_labels.dedup();
        assert_eq!(sorted_labels.len(), k);
    }

    #[test]
    #[should_panic(expected = "k must be ≤ number of samples")]
    fn test_k_means_fit_k_greater_than_m() {
        let (x, _) = create_test_data(); // 5 samples
        let k = 6; // k > m
        let max_iter = 10;
        let tol = 1e-6;

        let (_kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
    }

    #[test]
    fn test_k_means_fit_single_cluster() {
        let x = create_simple_integer_data();
        let k = 1;
        let max_iter = 100;
        let tol = 1e-6;

        let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);

        assert_eq!(kmeans_model.centroids.rows(), 1);
        assert_eq!(labels.len(), x.rows());
        assert!(labels.iter().all(|&l| l == 0));

        // The single centroid is the mean of all data points.
        let expected_centroid_x = x.column(0).iter().sum::<f64>() / x.rows() as f64;
        let expected_centroid_y = x.column(1).iter().sum::<f64>() / x.rows() as f64;

        assert!((kmeans_model.centroids[(0, 0)] - expected_centroid_x).abs() < 1e-9);
        assert!((kmeans_model.centroids[(0, 1)] - expected_centroid_y).abs() < 1e-9);
    }

    #[test]
    fn test_k_means_predict_empty_matrix() {
        let (x, k) = create_test_data();
        let max_iter = 10;
        let tol = 1e-6;
        let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);

        // The `Matrix` type does not support 0xN or Nx0 matrices, but a 0x0 matrix is a
        // valid edge case to test.
        let empty_x = FloatMatrix::from_rows_vec(vec![], 0, 0);
        let predicted_labels = kmeans_model.predict(&empty_x);
        assert!(predicted_labels.is_empty());
    }

    #[test]
    fn test_k_means_predict_single_sample() {
        let (x, k) = create_test_data();
        let max_iter = 10;
        let tol = 1e-6;
        let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);

        let single_sample = FloatMatrix::from_rows_vec(vec![1.1, 1.2], 1, 2);
        let predicted_label = kmeans_model.predict(&single_sample);
        assert_eq!(predicted_label.len(), 1);
        assert!(predicted_label[0] < k);
    }
}
|