use crate::matrix::Matrix; use crate::matrix::{FloatMatrix, SeriesOps}; use rand::rng; // Changed from rand::thread_rng use rand::seq::SliceRandom; pub struct KMeans { pub centroids: Matrix, // (k, n_features) } impl KMeans { /// Fit with k clusters. pub fn fit(x: &Matrix, k: usize, max_iter: usize, tol: f64) -> (Self, Vec) { let m = x.rows(); let n = x.cols(); assert!(k <= m, "k must be ≤ number of samples"); // ----- initialise centroids ----- let mut centroids = Matrix::zeros(k, n); if k == 1 { // For k=1, initialize the centroid to the mean of the data for j in 0..n { centroids[(0, j)] = x.column(j).iter().sum::() / m as f64; } } else { // For k > 1, pick k distinct rows at random let mut rng = rng(); // Changed from thread_rng() let mut indices: Vec = (0..m).collect(); indices.shuffle(&mut rng); for (c, &i) in indices[..k].iter().enumerate() { for j in 0..n { centroids[(c, j)] = x[(i, j)]; } } } let mut labels = vec![0usize; m]; let mut old_centroids = centroids.clone(); // Store initial centroids for first iteration's convergence check for _iter in 0..max_iter { // Renamed loop variable to _iter for clarity // ----- assignment step ----- let mut changed = false; for i in 0..m { let sample_row = x.row(i); let sample_matrix = FloatMatrix::from_rows_vec(sample_row, 1, n); let mut best = 0usize; let mut best_dist_sq = f64::MAX; for c in 0..k { let centroid_row = old_centroids.row(c); // Use old_centroids for distance calculation let centroid_matrix = FloatMatrix::from_rows_vec(centroid_row, 1, n); let diff = &sample_matrix - ¢roid_matrix; let sq_diff = &diff * &diff; let dist_sq = sq_diff.sum_horizontal()[0]; if dist_sq < best_dist_sq { best_dist_sq = dist_sq; best = c; } } if labels[i] != best { labels[i] = best; changed = true; } } // ----- update step ----- let mut counts = vec![0usize; k]; let mut new_centroids = Matrix::zeros(k, n); // New centroids for this iteration for i in 0..m { let c = labels[i]; counts[c] += 1; for j in 0..n { new_centroids[(c, j)] += x[(i, j)]; } } for c in 0..k { if counts[c] > 0 { for j in 0..n { new_centroids[(c, j)] /= counts[c] as f64; } } } // ----- convergence test ----- if !changed { break; // assignments stable } if tol > 0.0 { // optional centroid-shift tolerance let diff = &new_centroids - &old_centroids; // Calculate difference between new and old centroids let sq_diff = &diff * &diff; let shift = sq_diff.data().iter().sum::().sqrt(); // Sum all squared differences if shift < tol { break; } } old_centroids = new_centroids; // Update old_centroids for next iteration } ( Self { centroids: old_centroids, }, labels, ) // Return the final centroids } /// Predict nearest centroid for each sample. pub fn predict(&self, x: &Matrix) -> Vec { let m = x.rows(); let k = self.centroids.rows(); let n = x.cols(); if m == 0 { // Handle empty input matrix return Vec::new(); } let mut labels = vec![0usize; m]; for i in 0..m { let sample_row = x.row(i); let sample_matrix = FloatMatrix::from_rows_vec(sample_row, 1, n); let mut best = 0usize; let mut best_dist_sq = f64::MAX; for c in 0..k { let centroid_row = self.centroids.row(c); let centroid_matrix = FloatMatrix::from_rows_vec(centroid_row, 1, n); let diff = &sample_matrix - ¢roid_matrix; let sq_diff = &diff * &diff; let dist_sq = sq_diff.sum_horizontal()[0]; if dist_sq < best_dist_sq { best_dist_sq = dist_sq; best = c; } } labels[i] = best; } labels } } #[cfg(test)] mod tests { use super::*; use crate::matrix::FloatMatrix; fn create_test_data() -> (FloatMatrix, usize) { // Simple 2D data for testing K-Means // Cluster 1: (1,1), (1.5,1.5) // Cluster 2: (5,8), (8,8), (6,7) let data = vec![ 1.0, 1.0, // Sample 0 1.5, 1.5, // Sample 1 5.0, 8.0, // Sample 2 8.0, 8.0, // Sample 3 6.0, 7.0, // Sample 4 ]; let x = FloatMatrix::from_rows_vec(data, 5, 2); let k = 2; (x, k) } // Helper for single cluster test with exact mean fn create_simple_integer_data() -> FloatMatrix { // Data points: (1,1), (2,2), (3,3) FloatMatrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0], 3, 2) } #[test] fn test_k_means_fit_predict_basic() { let (x, k) = create_test_data(); let max_iter = 100; let tol = 1e-6; let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol); // Assertions for fit assert_eq!(kmeans_model.centroids.rows(), k); assert_eq!(kmeans_model.centroids.cols(), x.cols()); assert_eq!(labels.len(), x.rows()); // Check if labels are within expected range (0 to k-1) for &label in &labels { assert!(label < k); } // Predict with the same data let predicted_labels = kmeans_model.predict(&x); // The exact labels might vary due to random initialization, // but the clustering should be consistent. // We expect two clusters. Let's check if samples 0,1 are in one cluster // and samples 2,3,4 are in another. let cluster_0_members = vec![labels[0], labels[1]]; let cluster_1_members = vec![labels[2], labels[3], labels[4]]; // All members of cluster 0 should have the same label assert_eq!(cluster_0_members[0], cluster_0_members[1]); // All members of cluster 1 should have the same label assert_eq!(cluster_1_members[0], cluster_1_members[1]); assert_eq!(cluster_1_members[0], cluster_1_members[2]); // The two clusters should have different labels assert_ne!(cluster_0_members[0], cluster_1_members[0]); // Check predicted labels are consistent with fitted labels assert_eq!(labels, predicted_labels); // Test with a new sample let new_sample_data = vec![1.2, 1.3]; // Should be close to cluster 0 let new_sample = FloatMatrix::from_rows_vec(new_sample_data, 1, 2); let new_sample_label = kmeans_model.predict(&new_sample)[0]; assert_eq!(new_sample_label, cluster_0_members[0]); let new_sample_data_2 = vec![7.0, 7.5]; // Should be close to cluster 1 let new_sample_2 = FloatMatrix::from_rows_vec(new_sample_data_2, 1, 2); let new_sample_label_2 = kmeans_model.predict(&new_sample_2)[0]; assert_eq!(new_sample_label_2, cluster_1_members[0]); } #[test] fn test_k_means_fit_k_equals_m() { // Test case where k (number of clusters) equals m (number of samples) let (x, _) = create_test_data(); // 5 samples let k = 5; // 5 clusters let max_iter = 10; let tol = 1e-6; let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol); assert_eq!(kmeans_model.centroids.rows(), k); assert_eq!(labels.len(), x.rows()); // Each sample should be its own cluster, so labels should be unique let mut sorted_labels = labels.clone(); sorted_labels.sort_unstable(); assert_eq!(sorted_labels, vec![0, 1, 2, 3, 4]); } #[test] #[should_panic(expected = "k must be ≤ number of samples")] fn test_k_means_fit_k_greater_than_m() { let (x, _) = create_test_data(); // 5 samples let k = 6; // k > m let max_iter = 10; let tol = 1e-6; let (_kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol); } #[test] fn test_k_means_fit_single_cluster() { // Test with k=1 let x = create_simple_integer_data(); // Use integer data let k = 1; let max_iter = 100; let tol = 1e-6; // Reset tolerance let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol); assert_eq!(kmeans_model.centroids.rows(), 1); assert_eq!(labels.len(), x.rows()); // All labels should be 0 assert!(labels.iter().all(|&l| l == 0)); // Centroid should be the mean of all data points let expected_centroid_x = x.column(0).iter().sum::() / x.rows() as f64; let expected_centroid_y = x.column(1).iter().sum::() / x.rows() as f64; // Relax the assertion tolerance to match the algorithm's convergence tolerance assert!((kmeans_model.centroids[(0, 0)] - expected_centroid_x).abs() < 1e-6); assert!((kmeans_model.centroids[(0, 1)] - expected_centroid_y).abs() < 1e-6); } #[test] fn test_k_means_predict_empty_matrix() { let (x, k) = create_test_data(); let max_iter = 10; let tol = 1e-6; let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol); // Create a 0x0 matrix. This is allowed by Matrix constructor. let empty_x = FloatMatrix::from_rows_vec(vec![], 0, 0); let predicted_labels = kmeans_model.predict(&empty_x); assert!(predicted_labels.is_empty()); } #[test] fn test_k_means_predict_single_sample() { let (x, k) = create_test_data(); let max_iter = 10; let tol = 1e-6; let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol); let single_sample = FloatMatrix::from_rows_vec(vec![1.1, 1.2], 1, 2); let predicted_label = kmeans_model.predict(&single_sample); assert_eq!(predicted_label.len(), 1); assert!(predicted_label[0] < k); } }