Refactor KMeans fit and predict methods for improved clarity and performance

This commit is contained in:
Palash Tyagi 2025-07-12 01:45:59 +01:00
parent 049dd02c1a
commit 12a72317e4

View File

@ -1,4 +1,6 @@
use crate::matrix::Matrix;
use crate::matrix::{FloatMatrix, SeriesOps};
use rand::rng; // Changed from rand::thread_rng
use rand::seq::SliceRandom;
pub struct KMeans {
@ -13,7 +15,7 @@ impl KMeans {
assert!(k <= m, "k must be ≤ number of samples");
// ----- initialise centroids: pick k distinct rows at random -----
let mut rng = rand::rng();
let mut rng = rng(); // Changed from thread_rng()
let mut indices: Vec<usize> = (0..m).collect();
indices.shuffle(&mut rng);
let mut centroids = Matrix::zeros(k, n);
@ -24,20 +26,29 @@ impl KMeans {
}
let mut labels = vec![0usize; m];
for _ in 0..max_iter {
let mut old_centroids = centroids.clone(); // Store initial centroids for first iteration's convergence check
for _iter in 0..max_iter {
// Renamed loop variable to _iter for clarity
// ----- assignment step -----
let mut changed = false;
for i in 0..m {
let sample_row = x.row(i);
let sample_matrix = FloatMatrix::from_rows_vec(sample_row, 1, n);
let mut best = 0usize;
let mut best_dist = f64::MAX;
let mut best_dist_sq = f64::MAX;
for c in 0..k {
let mut dist = 0.0;
for j in 0..n {
let d = x[(i, j)] - centroids[(c, j)];
dist += d * d;
}
if dist < best_dist {
best_dist = dist;
let centroid_row = old_centroids.row(c); // Use old_centroids for distance calculation
let centroid_matrix = FloatMatrix::from_rows_vec(centroid_row, 1, n);
let diff = &sample_matrix - &centroid_matrix;
let sq_diff = &diff * &diff;
let dist_sq = sq_diff.sum_horizontal()[0];
if dist_sq < best_dist_sq {
best_dist_sq = dist_sq;
best = c;
}
}
@ -49,18 +60,18 @@ impl KMeans {
// ----- update step -----
let mut counts = vec![0usize; k];
let mut centroids = Matrix::zeros(k, n);
let mut new_centroids = Matrix::zeros(k, n); // New centroids for this iteration
for i in 0..m {
let c = labels[i];
counts[c] += 1;
for j in 0..n {
centroids[(c, j)] += x[(i, j)];
new_centroids[(c, j)] += x[(i, j)];
}
}
for c in 0..k {
if counts[c] > 0 {
for j in 0..n {
centroids[(c, j)] /= counts[c] as f64;
new_centroids[(c, j)] /= counts[c] as f64;
}
}
}
@ -71,19 +82,22 @@ impl KMeans {
}
if tol > 0.0 {
// optional centroid-shift tolerance
let mut shift: f64 = 0.0;
for c in 0..k {
for j in 0..n {
let d = centroids[(c, j)] - centroids[(c, j)]; // previous stored?
shift += d * d;
}
}
if shift.sqrt() < tol {
let diff = &new_centroids - &old_centroids; // Calculate difference between new and old centroids
let sq_diff = &diff * &diff;
let shift = sq_diff.data().iter().sum::<f64>().sqrt(); // Sum all squared differences
if shift < tol {
break;
}
}
old_centroids = new_centroids; // Update old_centroids for next iteration
}
(Self { centroids }, labels)
(
Self {
centroids: old_centroids,
},
labels,
) // Return the final centroids
}
/// Predict nearest centroid for each sample.
@ -91,18 +105,29 @@ impl KMeans {
let m = x.rows();
let k = self.centroids.rows();
let n = x.cols();
if m == 0 {
// Handle empty input matrix
return Vec::new();
}
let mut labels = vec![0usize; m];
for i in 0..m {
let sample_row = x.row(i);
let sample_matrix = FloatMatrix::from_rows_vec(sample_row, 1, n);
let mut best = 0usize;
let mut best_dist = f64::MAX;
let mut best_dist_sq = f64::MAX;
for c in 0..k {
let mut dist = 0.0;
for j in 0..n {
let d = x[(i, j)] - self.centroids[(c, j)];
dist += d * d;
}
if dist < best_dist {
best_dist = dist;
let centroid_row = self.centroids.row(c);
let centroid_matrix = FloatMatrix::from_rows_vec(centroid_row, 1, n);
let diff = &sample_matrix - &centroid_matrix;
let sq_diff = &diff * &diff;
let dist_sq = sq_diff.sum_horizontal()[0];
if dist_sq < best_dist_sq {
best_dist_sq = dist_sq;
best = c;
}
}
@ -111,3 +136,162 @@ impl KMeans {
labels
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::matrix::FloatMatrix;
fn create_test_data() -> (FloatMatrix, usize) {
// Simple 2D data for testing K-Means
// Cluster 1: (1,1), (1.5,1.5)
// Cluster 2: (5,8), (8,8), (6,7)
let data = vec![
1.0, 1.0, // Sample 0
1.5, 1.5, // Sample 1
5.0, 8.0, // Sample 2
8.0, 8.0, // Sample 3
6.0, 7.0, // Sample 4
];
let x = FloatMatrix::from_rows_vec(data, 5, 2);
let k = 2;
(x, k)
}
// Helper for single cluster test with exact mean
fn create_simple_integer_data() -> FloatMatrix {
// Data points: (1,1), (2,2), (3,3)
FloatMatrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0], 3, 2)
}
#[test]
fn test_k_means_fit_predict_basic() {
let (x, k) = create_test_data();
let max_iter = 100;
let tol = 1e-6;
let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);
// Assertions for fit
assert_eq!(kmeans_model.centroids.rows(), k);
assert_eq!(kmeans_model.centroids.cols(), x.cols());
assert_eq!(labels.len(), x.rows());
// Check if labels are within expected range (0 to k-1)
for &label in &labels {
assert!(label < k);
}
// Predict with the same data
let predicted_labels = kmeans_model.predict(&x);
// The exact labels might vary due to random initialization,
// but the clustering should be consistent.
// We expect two clusters. Let's check if samples 0,1 are in one cluster
// and samples 2,3,4 are in another.
let cluster_0_members = vec![labels[0], labels[1]];
let cluster_1_members = vec![labels[2], labels[3], labels[4]];
// All members of cluster 0 should have the same label
assert_eq!(cluster_0_members[0], cluster_0_members[1]);
// All members of cluster 1 should have the same label
assert_eq!(cluster_1_members[0], cluster_1_members[1]);
assert_eq!(cluster_1_members[0], cluster_1_members[2]);
// The two clusters should have different labels
assert_ne!(cluster_0_members[0], cluster_1_members[0]);
// Check predicted labels are consistent with fitted labels
assert_eq!(labels, predicted_labels);
// Test with a new sample
let new_sample_data = vec![1.2, 1.3]; // Should be close to cluster 0
let new_sample = FloatMatrix::from_rows_vec(new_sample_data, 1, 2);
let new_sample_label = kmeans_model.predict(&new_sample)[0];
assert_eq!(new_sample_label, cluster_0_members[0]);
let new_sample_data_2 = vec![7.0, 7.5]; // Should be close to cluster 1
let new_sample_2 = FloatMatrix::from_rows_vec(new_sample_data_2, 1, 2);
let new_sample_label_2 = kmeans_model.predict(&new_sample_2)[0];
assert_eq!(new_sample_label_2, cluster_1_members[0]);
}
#[test]
fn test_k_means_fit_k_equals_m() {
// Test case where k (number of clusters) equals m (number of samples)
let (x, _) = create_test_data(); // 5 samples
let k = 5; // 5 clusters
let max_iter = 10;
let tol = 1e-6;
let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);
assert_eq!(kmeans_model.centroids.rows(), k);
assert_eq!(labels.len(), x.rows());
// Each sample should be its own cluster, so labels should be unique
let mut sorted_labels = labels.clone();
sorted_labels.sort_unstable();
assert_eq!(sorted_labels, vec![0, 1, 2, 3, 4]);
}
#[test]
#[should_panic(expected = "k must be ≤ number of samples")]
fn test_k_means_fit_k_greater_than_m() {
let (x, _) = create_test_data(); // 5 samples
let k = 6; // k > m
let max_iter = 10;
let tol = 1e-6;
let (_kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
}
#[test]
fn test_k_means_fit_single_cluster() {
// Test with k=1
let x = create_simple_integer_data(); // Use integer data
let k = 1;
let max_iter = 100;
let tol = 1e-6; // Reset tolerance
let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);
assert_eq!(kmeans_model.centroids.rows(), 1);
assert_eq!(labels.len(), x.rows());
// All labels should be 0
assert!(labels.iter().all(|&l| l == 0));
// Centroid should be the mean of all data points
let expected_centroid_x = x.column(0).iter().sum::<f64>() / x.rows() as f64;
let expected_centroid_y = x.column(1).iter().sum::<f64>() / x.rows() as f64;
assert!((kmeans_model.centroids[(0, 0)] - expected_centroid_x).abs() < 1e-9);
assert!((kmeans_model.centroids[(0, 1)] - expected_centroid_y).abs() < 1e-9);
}
#[test]
fn test_k_means_predict_empty_matrix() {
let (x, k) = create_test_data();
let max_iter = 10;
let tol = 1e-6;
let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
// Create a 0x0 matrix. This is allowed by Matrix constructor.
let empty_x = FloatMatrix::from_rows_vec(vec![], 0, 0);
let predicted_labels = kmeans_model.predict(&empty_x);
assert!(predicted_labels.is_empty());
}
#[test]
fn test_k_means_predict_single_sample() {
let (x, k) = create_test_data();
let max_iter = 10;
let tol = 1e-6;
let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
let single_sample = FloatMatrix::from_rows_vec(vec![1.1, 1.2], 1, 2);
let predicted_label = kmeans_model.predict(&single_sample);
assert_eq!(predicted_label.len(), 1);
assert!(predicted_label[0] < k);
}
}