From 12a72317e4ed93ee54fce8ea0aff83b613067605 Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Sat, 12 Jul 2025 01:45:59 +0100 Subject: [PATCH] Refactor KMeans fit and predict methods for improved clarity and performance --- src/compute/models/k_means.rs | 244 +++++++++++++++++++++++++++++----- 1 file changed, 214 insertions(+), 30 deletions(-) diff --git a/src/compute/models/k_means.rs b/src/compute/models/k_means.rs index 7b7fcf0..fe67e8e 100644 --- a/src/compute/models/k_means.rs +++ b/src/compute/models/k_means.rs @@ -1,4 +1,6 @@ use crate::matrix::Matrix; +use crate::matrix::{FloatMatrix, SeriesOps}; +use rand::rng; // Changed from rand::thread_rng use rand::seq::SliceRandom; pub struct KMeans { @@ -13,7 +15,7 @@ impl KMeans { assert!(k <= m, "k must be ≤ number of samples"); // ----- initialise centroids: pick k distinct rows at random ----- - let mut rng = rand::rng(); + let mut rng = rng(); // Changed from thread_rng() let mut indices: Vec = (0..m).collect(); indices.shuffle(&mut rng); let mut centroids = Matrix::zeros(k, n); @@ -24,20 +26,29 @@ impl KMeans { } let mut labels = vec![0usize; m]; - for _ in 0..max_iter { + let mut old_centroids = centroids.clone(); // Store initial centroids for first iteration's convergence check + + for _iter in 0..max_iter { + // Renamed loop variable to _iter for clarity // ----- assignment step ----- let mut changed = false; for i in 0..m { + let sample_row = x.row(i); + let sample_matrix = FloatMatrix::from_rows_vec(sample_row, 1, n); + let mut best = 0usize; - let mut best_dist = f64::MAX; + let mut best_dist_sq = f64::MAX; + for c in 0..k { - let mut dist = 0.0; - for j in 0..n { - let d = x[(i, j)] - centroids[(c, j)]; - dist += d * d; - } - if dist < best_dist { - best_dist = dist; + let centroid_row = old_centroids.row(c); // Use old_centroids for distance calculation + let centroid_matrix = FloatMatrix::from_rows_vec(centroid_row, 1, n); + + let diff = &sample_matrix - ¢roid_matrix; + let sq_diff = &diff * &diff; + let dist_sq = sq_diff.sum_horizontal()[0]; + + if dist_sq < best_dist_sq { + best_dist_sq = dist_sq; best = c; } } @@ -49,18 +60,18 @@ impl KMeans { // ----- update step ----- let mut counts = vec![0usize; k]; - let mut centroids = Matrix::zeros(k, n); + let mut new_centroids = Matrix::zeros(k, n); // New centroids for this iteration for i in 0..m { let c = labels[i]; counts[c] += 1; for j in 0..n { - centroids[(c, j)] += x[(i, j)]; + new_centroids[(c, j)] += x[(i, j)]; } } for c in 0..k { if counts[c] > 0 { for j in 0..n { - centroids[(c, j)] /= counts[c] as f64; + new_centroids[(c, j)] /= counts[c] as f64; } } } @@ -71,19 +82,22 @@ impl KMeans { } if tol > 0.0 { // optional centroid-shift tolerance - let mut shift: f64 = 0.0; - for c in 0..k { - for j in 0..n { - let d = centroids[(c, j)] - centroids[(c, j)]; // previous stored? - shift += d * d; - } - } - if shift.sqrt() < tol { + let diff = &new_centroids - &old_centroids; // Calculate difference between new and old centroids + let sq_diff = &diff * &diff; + let shift = sq_diff.data().iter().sum::().sqrt(); // Sum all squared differences + + if shift < tol { break; } } + old_centroids = new_centroids; // Update old_centroids for next iteration } - (Self { centroids }, labels) + ( + Self { + centroids: old_centroids, + }, + labels, + ) // Return the final centroids } /// Predict nearest centroid for each sample. @@ -91,18 +105,29 @@ impl KMeans { let m = x.rows(); let k = self.centroids.rows(); let n = x.cols(); + + if m == 0 { + // Handle empty input matrix + return Vec::new(); + } + let mut labels = vec![0usize; m]; for i in 0..m { + let sample_row = x.row(i); + let sample_matrix = FloatMatrix::from_rows_vec(sample_row, 1, n); + let mut best = 0usize; - let mut best_dist = f64::MAX; + let mut best_dist_sq = f64::MAX; for c in 0..k { - let mut dist = 0.0; - for j in 0..n { - let d = x[(i, j)] - self.centroids[(c, j)]; - dist += d * d; - } - if dist < best_dist { - best_dist = dist; + let centroid_row = self.centroids.row(c); + let centroid_matrix = FloatMatrix::from_rows_vec(centroid_row, 1, n); + + let diff = &sample_matrix - ¢roid_matrix; + let sq_diff = &diff * &diff; + let dist_sq = sq_diff.sum_horizontal()[0]; + + if dist_sq < best_dist_sq { + best_dist_sq = dist_sq; best = c; } } @@ -111,3 +136,162 @@ impl KMeans { labels } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::matrix::FloatMatrix; + + fn create_test_data() -> (FloatMatrix, usize) { + // Simple 2D data for testing K-Means + // Cluster 1: (1,1), (1.5,1.5) + // Cluster 2: (5,8), (8,8), (6,7) + let data = vec![ + 1.0, 1.0, // Sample 0 + 1.5, 1.5, // Sample 1 + 5.0, 8.0, // Sample 2 + 8.0, 8.0, // Sample 3 + 6.0, 7.0, // Sample 4 + ]; + let x = FloatMatrix::from_rows_vec(data, 5, 2); + let k = 2; + (x, k) + } + + // Helper for single cluster test with exact mean + fn create_simple_integer_data() -> FloatMatrix { + // Data points: (1,1), (2,2), (3,3) + FloatMatrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0], 3, 2) + } + + #[test] + fn test_k_means_fit_predict_basic() { + let (x, k) = create_test_data(); + let max_iter = 100; + let tol = 1e-6; + + let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol); + + // Assertions for fit + assert_eq!(kmeans_model.centroids.rows(), k); + assert_eq!(kmeans_model.centroids.cols(), x.cols()); + assert_eq!(labels.len(), x.rows()); + + // Check if labels are within expected range (0 to k-1) + for &label in &labels { + assert!(label < k); + } + + // Predict with the same data + let predicted_labels = kmeans_model.predict(&x); + + // The exact labels might vary due to random initialization, + // but the clustering should be consistent. + // We expect two clusters. Let's check if samples 0,1 are in one cluster + // and samples 2,3,4 are in another. + let cluster_0_members = vec![labels[0], labels[1]]; + let cluster_1_members = vec![labels[2], labels[3], labels[4]]; + + // All members of cluster 0 should have the same label + assert_eq!(cluster_0_members[0], cluster_0_members[1]); + // All members of cluster 1 should have the same label + assert_eq!(cluster_1_members[0], cluster_1_members[1]); + assert_eq!(cluster_1_members[0], cluster_1_members[2]); + // The two clusters should have different labels + assert_ne!(cluster_0_members[0], cluster_1_members[0]); + + // Check predicted labels are consistent with fitted labels + assert_eq!(labels, predicted_labels); + + // Test with a new sample + let new_sample_data = vec![1.2, 1.3]; // Should be close to cluster 0 + let new_sample = FloatMatrix::from_rows_vec(new_sample_data, 1, 2); + let new_sample_label = kmeans_model.predict(&new_sample)[0]; + assert_eq!(new_sample_label, cluster_0_members[0]); + + let new_sample_data_2 = vec![7.0, 7.5]; // Should be close to cluster 1 + let new_sample_2 = FloatMatrix::from_rows_vec(new_sample_data_2, 1, 2); + let new_sample_label_2 = kmeans_model.predict(&new_sample_2)[0]; + assert_eq!(new_sample_label_2, cluster_1_members[0]); + } + + #[test] + fn test_k_means_fit_k_equals_m() { + // Test case where k (number of clusters) equals m (number of samples) + let (x, _) = create_test_data(); // 5 samples + let k = 5; // 5 clusters + let max_iter = 10; + let tol = 1e-6; + + let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol); + + assert_eq!(kmeans_model.centroids.rows(), k); + assert_eq!(labels.len(), x.rows()); + + // Each sample should be its own cluster, so labels should be unique + let mut sorted_labels = labels.clone(); + sorted_labels.sort_unstable(); + assert_eq!(sorted_labels, vec![0, 1, 2, 3, 4]); + } + + #[test] + #[should_panic(expected = "k must be ≤ number of samples")] + fn test_k_means_fit_k_greater_than_m() { + let (x, _) = create_test_data(); // 5 samples + let k = 6; // k > m + let max_iter = 10; + let tol = 1e-6; + + let (_kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol); + } + + #[test] + fn test_k_means_fit_single_cluster() { + // Test with k=1 + let x = create_simple_integer_data(); // Use integer data + let k = 1; + let max_iter = 100; + let tol = 1e-6; // Reset tolerance + + let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol); + + assert_eq!(kmeans_model.centroids.rows(), 1); + assert_eq!(labels.len(), x.rows()); + + // All labels should be 0 + assert!(labels.iter().all(|&l| l == 0)); + + // Centroid should be the mean of all data points + let expected_centroid_x = x.column(0).iter().sum::() / x.rows() as f64; + let expected_centroid_y = x.column(1).iter().sum::() / x.rows() as f64; + + assert!((kmeans_model.centroids[(0, 0)] - expected_centroid_x).abs() < 1e-9); + assert!((kmeans_model.centroids[(0, 1)] - expected_centroid_y).abs() < 1e-9); + } + + #[test] + fn test_k_means_predict_empty_matrix() { + let (x, k) = create_test_data(); + let max_iter = 10; + let tol = 1e-6; + let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol); + + // Create a 0x0 matrix. This is allowed by Matrix constructor. + let empty_x = FloatMatrix::from_rows_vec(vec![], 0, 0); + let predicted_labels = kmeans_model.predict(&empty_x); + assert!(predicted_labels.is_empty()); + } + + #[test] + fn test_k_means_predict_single_sample() { + let (x, k) = create_test_data(); + let max_iter = 10; + let tol = 1e-6; + let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol); + + let single_sample = FloatMatrix::from_rows_vec(vec![1.1, 1.2], 1, 2); + let predicted_label = kmeans_model.predict(&single_sample); + assert_eq!(predicted_label.len(), 1); + assert!(predicted_label[0] < k); + } +}