Mirror of https://github.com/Magnus167/rustframe.git (synced 2025-08-20 04:00:01 +00:00)
Refactor KMeans fit and predict methods for improved clarity and performance

This commit is contained in:
parent 049dd02c1a
commit 12a72317e4
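The heart of the change, in both `fit` and `predict`, is replacing the per-coordinate distance loop with row-level matrix operations; the quantity being minimised is still the squared Euclidean distance to each centroid. A minimal, dependency-free sketch of that assignment rule (plain slices stand in for rustframe's `FloatMatrix`, and `nearest_centroid` is a hypothetical helper, not part of the crate):

```rust
// Standalone sketch of the nearest-centroid rule refactored in this commit.
// Plain slices stand in for rustframe's FloatMatrix; `nearest_centroid` is a
// hypothetical helper, not part of the crate.
fn nearest_centroid(sample: &[f64], centroids: &[Vec<f64>]) -> usize {
    let mut best = 0usize;
    let mut best_dist_sq = f64::MAX;
    for (c, centroid) in centroids.iter().enumerate() {
        // Squared Euclidean distance: sum of squared per-coordinate differences.
        let dist_sq: f64 = sample
            .iter()
            .zip(centroid)
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        if dist_sq < best_dist_sq {
            best_dist_sq = dist_sq;
            best = c;
        }
    }
    best
}

fn main() {
    let centroids = vec![vec![1.0, 1.0], vec![6.0, 7.5]];
    assert_eq!(nearest_centroid(&[1.2, 1.3], &centroids), 0);
    assert_eq!(nearest_centroid(&[7.0, 7.5], &centroids), 1);
}
```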
@@ -1,4 +1,6 @@
 use crate::matrix::Matrix;
+use crate::matrix::{FloatMatrix, SeriesOps};
+use rand::rng; // Changed from rand::thread_rng
 use rand::seq::SliceRandom;
 
 pub struct KMeans {
@@ -13,7 +15,7 @@ impl KMeans {
         assert!(k <= m, "k must be ≤ number of samples");
 
         // ----- initialise centroids: pick k distinct rows at random -----
-        let mut rng = rand::rng();
+        let mut rng = rng(); // Changed from thread_rng()
         let mut indices: Vec<usize> = (0..m).collect();
         indices.shuffle(&mut rng);
         let mut centroids = Matrix::zeros(k, n);
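The hunk above tracks the rand 0.9 API, where `thread_rng()` was renamed to `rand::rng()`; centroid initialisation still shuffles the row indices and keeps the first k, which are distinct by construction. A self-contained sketch of that strategy, assuming `rand = "0.9"` in `Cargo.toml` (the helper name is illustrative):

```rust
// Sketch of the "pick k distinct rows at random" initialisation, assuming the
// rand 0.9 API imported in the diff (rand::rng() is the successor of thread_rng()).
use rand::rng;
use rand::seq::SliceRandom;

fn pick_k_distinct_indices(m: usize, k: usize) -> Vec<usize> {
    assert!(k <= m, "k must be ≤ number of samples");
    let mut rng = rng();
    let mut indices: Vec<usize> = (0..m).collect();
    indices.shuffle(&mut rng);
    indices.truncate(k); // the first k shuffled indices are distinct by construction
    indices
}

fn main() {
    let picked = pick_k_distinct_indices(5, 2);
    assert_eq!(picked.len(), 2);
    assert_ne!(picked[0], picked[1]);
}
```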
@@ -24,20 +26,29 @@ impl KMeans {
         }
 
         let mut labels = vec![0usize; m];
-        for _ in 0..max_iter {
+        let mut old_centroids = centroids.clone(); // Store initial centroids for first iteration's convergence check
+
+        for _iter in 0..max_iter {
+            // Renamed loop variable to _iter for clarity
             // ----- assignment step -----
             let mut changed = false;
             for i in 0..m {
+                let sample_row = x.row(i);
+                let sample_matrix = FloatMatrix::from_rows_vec(sample_row, 1, n);
+
                 let mut best = 0usize;
-                let mut best_dist = f64::MAX;
+                let mut best_dist_sq = f64::MAX;
+
                 for c in 0..k {
-                    let mut dist = 0.0;
-                    for j in 0..n {
-                        let d = x[(i, j)] - centroids[(c, j)];
-                        dist += d * d;
-                    }
-                    if dist < best_dist {
-                        best_dist = dist;
+                    let centroid_row = old_centroids.row(c); // Use old_centroids for distance calculation
+                    let centroid_matrix = FloatMatrix::from_rows_vec(centroid_row, 1, n);
+
+                    let diff = &sample_matrix - &centroid_matrix;
+                    let sq_diff = &diff * &diff;
+                    let dist_sq = sq_diff.sum_horizontal()[0];
+
+                    if dist_sq < best_dist_sq {
+                        best_dist_sq = dist_sq;
                         best = c;
                     }
                 }
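In the rewritten assignment step above, each squared distance is built from whole-row operations instead of a `for j in 0..n` accumulation. A condensed sketch of that expression, assuming rustframe's `FloatMatrix` behaves exactly as it is used in the diff (`row(i)` returning the row's data, elementwise `&a - &b` and `&a * &b`, and `SeriesOps::sum_horizontal()` yielding one sum per row); the `rustframe::matrix` path and the `squared_distance` helper are assumptions, not confirmed crate API:

```rust
// Condensed sketch of the new distance expression; relies on the FloatMatrix
// operations exactly as they appear in the diff. `squared_distance` is a
// hypothetical helper, not part of the crate.
use rustframe::matrix::{FloatMatrix, SeriesOps};

fn squared_distance(x: &FloatMatrix, i: usize, centroids: &FloatMatrix, c: usize) -> f64 {
    let n = x.cols();
    let sample = FloatMatrix::from_rows_vec(x.row(i), 1, n);
    let centroid = FloatMatrix::from_rows_vec(centroids.row(c), 1, n);
    let diff = &sample - &centroid;  // 1 x n elementwise difference
    let sq_diff = &diff * &diff;     // elementwise square
    sq_diff.sum_horizontal()[0]      // one row, so one horizontal sum: the squared distance
}
```

Since the intermediate matrices are 1 x n, `sum_horizontal()` produces a single value, and element `[0]` equals the sum the removed inner loop used to accumulate.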
@@ -49,18 +60,18 @@ impl KMeans {
 
             // ----- update step -----
             let mut counts = vec![0usize; k];
-            let mut centroids = Matrix::zeros(k, n);
+            let mut new_centroids = Matrix::zeros(k, n); // New centroids for this iteration
             for i in 0..m {
                 let c = labels[i];
                 counts[c] += 1;
                 for j in 0..n {
-                    centroids[(c, j)] += x[(i, j)];
+                    new_centroids[(c, j)] += x[(i, j)];
                 }
             }
             for c in 0..k {
                 if counts[c] > 0 {
                     for j in 0..n {
-                        centroids[(c, j)] /= counts[c] as f64;
+                        new_centroids[(c, j)] /= counts[c] as f64;
                     }
                 }
             }
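The update step above is unchanged apart from writing into `new_centroids`: each centroid becomes the mean of the samples assigned to it, and a cluster that received no samples keeps its zeroed row. A dependency-free sketch of that computation, with a row-major `Vec<f64>` standing in for the matrix and an illustrative helper name:

```rust
// Row-major Vec<f64> stands in for Matrix; `mean_centroids` is a hypothetical helper.
fn mean_centroids(x: &[f64], labels: &[usize], m: usize, n: usize, k: usize) -> Vec<f64> {
    let mut counts = vec![0usize; k];
    let mut new_centroids = vec![0.0f64; k * n];
    for i in 0..m {
        let c = labels[i];
        counts[c] += 1;
        for j in 0..n {
            new_centroids[c * n + j] += x[i * n + j]; // accumulate per-cluster sums
        }
    }
    for c in 0..k {
        if counts[c] > 0 {
            for j in 0..n {
                new_centroids[c * n + j] /= counts[c] as f64; // sum -> mean
            }
        }
    }
    new_centroids
}

fn main() {
    // Two samples in cluster 0, one in cluster 1.
    let x = vec![1.0, 1.0, 3.0, 3.0, 8.0, 8.0];
    let centroids = mean_centroids(&x, &[0, 0, 1], 3, 2, 2);
    assert_eq!(centroids, vec![2.0, 2.0, 8.0, 8.0]);
}
```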
@@ -71,19 +82,22 @@ impl KMeans {
             }
             if tol > 0.0 {
                 // optional centroid-shift tolerance
-                let mut shift: f64 = 0.0;
-                for c in 0..k {
-                    for j in 0..n {
-                        let d = centroids[(c, j)] - centroids[(c, j)]; // previous stored?
-                        shift += d * d;
-                    }
-                }
-                if shift.sqrt() < tol {
+                let diff = &new_centroids - &old_centroids; // Calculate difference between new and old centroids
+                let sq_diff = &diff * &diff;
+                let shift = sq_diff.data().iter().sum::<f64>().sqrt(); // Sum all squared differences
+
+                if shift < tol {
                     break;
                 }
             }
+            old_centroids = new_centroids; // Update old_centroids for next iteration
         }
-        (Self { centroids }, labels)
+        (
+            Self {
+                centroids: old_centroids,
+            },
+            labels,
+        ) // Return the final centroids
     }
 
     /// Predict nearest centroid for each sample.
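The convergence check above now measures the total centroid shift as the square root of the sum of all squared element differences between `new_centroids` and `old_centroids` (a Frobenius norm), replacing the earlier loop that subtracted each centroid entry from itself and therefore always produced zero. A dependency-free sketch of the criterion (the helper name is illustrative):

```rust
// `has_converged` is a hypothetical helper mirroring the new tolerance check:
// shift = sqrt(sum over all elements of (new - old)^2), compared against tol.
fn has_converged(old_centroids: &[f64], new_centroids: &[f64], tol: f64) -> bool {
    let shift = old_centroids
        .iter()
        .zip(new_centroids)
        .map(|(o, n)| (n - o) * (n - o))
        .sum::<f64>()
        .sqrt();
    shift < tol
}

fn main() {
    let old = [1.0, 1.0, 8.0, 8.0];
    let new = [1.0005, 1.0, 8.0, 7.9995];
    assert!(has_converged(&old, &new, 1e-2));  // tiny shift: stop iterating
    assert!(!has_converged(&old, &new, 1e-6)); // stricter tolerance: keep going
}
```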
@@ -91,18 +105,29 @@ impl KMeans {
         let m = x.rows();
         let k = self.centroids.rows();
         let n = x.cols();
+
+        if m == 0 {
+            // Handle empty input matrix
+            return Vec::new();
+        }
+
         let mut labels = vec![0usize; m];
         for i in 0..m {
+            let sample_row = x.row(i);
+            let sample_matrix = FloatMatrix::from_rows_vec(sample_row, 1, n);
+
             let mut best = 0usize;
-            let mut best_dist = f64::MAX;
+            let mut best_dist_sq = f64::MAX;
             for c in 0..k {
-                let mut dist = 0.0;
-                for j in 0..n {
-                    let d = x[(i, j)] - self.centroids[(c, j)];
-                    dist += d * d;
-                }
-                if dist < best_dist {
-                    best_dist = dist;
+                let centroid_row = self.centroids.row(c);
+                let centroid_matrix = FloatMatrix::from_rows_vec(centroid_row, 1, n);
+
+                let diff = &sample_matrix - &centroid_matrix;
+                let sq_diff = &diff * &diff;
+                let dist_sq = sq_diff.sum_horizontal()[0];
+
+                if dist_sq < best_dist_sq {
+                    best_dist_sq = dist_sq;
                     best = c;
                 }
             }
@@ -111,3 +136,162 @@ impl KMeans {
         labels
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::matrix::FloatMatrix;
+
+    fn create_test_data() -> (FloatMatrix, usize) {
+        // Simple 2D data for testing K-Means
+        // Cluster 1: (1,1), (1.5,1.5)
+        // Cluster 2: (5,8), (8,8), (6,7)
+        let data = vec![
+            1.0, 1.0, // Sample 0
+            1.5, 1.5, // Sample 1
+            5.0, 8.0, // Sample 2
+            8.0, 8.0, // Sample 3
+            6.0, 7.0, // Sample 4
+        ];
+        let x = FloatMatrix::from_rows_vec(data, 5, 2);
+        let k = 2;
+        (x, k)
+    }
+
+    // Helper for single cluster test with exact mean
+    fn create_simple_integer_data() -> FloatMatrix {
+        // Data points: (1,1), (2,2), (3,3)
+        FloatMatrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0], 3, 2)
+    }
+
+    #[test]
+    fn test_k_means_fit_predict_basic() {
+        let (x, k) = create_test_data();
+        let max_iter = 100;
+        let tol = 1e-6;
+
+        let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);
+
+        // Assertions for fit
+        assert_eq!(kmeans_model.centroids.rows(), k);
+        assert_eq!(kmeans_model.centroids.cols(), x.cols());
+        assert_eq!(labels.len(), x.rows());
+
+        // Check if labels are within expected range (0 to k-1)
+        for &label in &labels {
+            assert!(label < k);
+        }
+
+        // Predict with the same data
+        let predicted_labels = kmeans_model.predict(&x);
+
+        // The exact labels might vary due to random initialization,
+        // but the clustering should be consistent.
+        // We expect two clusters. Let's check if samples 0,1 are in one cluster
+        // and samples 2,3,4 are in another.
+        let cluster_0_members = vec![labels[0], labels[1]];
+        let cluster_1_members = vec![labels[2], labels[3], labels[4]];
+
+        // All members of cluster 0 should have the same label
+        assert_eq!(cluster_0_members[0], cluster_0_members[1]);
+        // All members of cluster 1 should have the same label
+        assert_eq!(cluster_1_members[0], cluster_1_members[1]);
+        assert_eq!(cluster_1_members[0], cluster_1_members[2]);
+        // The two clusters should have different labels
+        assert_ne!(cluster_0_members[0], cluster_1_members[0]);
+
+        // Check predicted labels are consistent with fitted labels
+        assert_eq!(labels, predicted_labels);
+
+        // Test with a new sample
+        let new_sample_data = vec![1.2, 1.3]; // Should be close to cluster 0
+        let new_sample = FloatMatrix::from_rows_vec(new_sample_data, 1, 2);
+        let new_sample_label = kmeans_model.predict(&new_sample)[0];
+        assert_eq!(new_sample_label, cluster_0_members[0]);
+
+        let new_sample_data_2 = vec![7.0, 7.5]; // Should be close to cluster 1
+        let new_sample_2 = FloatMatrix::from_rows_vec(new_sample_data_2, 1, 2);
+        let new_sample_label_2 = kmeans_model.predict(&new_sample_2)[0];
+        assert_eq!(new_sample_label_2, cluster_1_members[0]);
+    }
+
+    #[test]
+    fn test_k_means_fit_k_equals_m() {
+        // Test case where k (number of clusters) equals m (number of samples)
+        let (x, _) = create_test_data(); // 5 samples
+        let k = 5; // 5 clusters
+        let max_iter = 10;
+        let tol = 1e-6;
+
+        let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);
+
+        assert_eq!(kmeans_model.centroids.rows(), k);
+        assert_eq!(labels.len(), x.rows());
+
+        // Each sample should be its own cluster, so labels should be unique
+        let mut sorted_labels = labels.clone();
+        sorted_labels.sort_unstable();
+        assert_eq!(sorted_labels, vec![0, 1, 2, 3, 4]);
+    }
+
+    #[test]
+    #[should_panic(expected = "k must be ≤ number of samples")]
+    fn test_k_means_fit_k_greater_than_m() {
+        let (x, _) = create_test_data(); // 5 samples
+        let k = 6; // k > m
+        let max_iter = 10;
+        let tol = 1e-6;
+
+        let (_kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
+    }
+
+    #[test]
+    fn test_k_means_fit_single_cluster() {
+        // Test with k=1
+        let x = create_simple_integer_data(); // Use integer data
+        let k = 1;
+        let max_iter = 100;
+        let tol = 1e-6; // Reset tolerance
+
+        let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);
+
+        assert_eq!(kmeans_model.centroids.rows(), 1);
+        assert_eq!(labels.len(), x.rows());
+
+        // All labels should be 0
+        assert!(labels.iter().all(|&l| l == 0));
+
+        // Centroid should be the mean of all data points
+        let expected_centroid_x = x.column(0).iter().sum::<f64>() / x.rows() as f64;
+        let expected_centroid_y = x.column(1).iter().sum::<f64>() / x.rows() as f64;
+
+        assert!((kmeans_model.centroids[(0, 0)] - expected_centroid_x).abs() < 1e-9);
+        assert!((kmeans_model.centroids[(0, 1)] - expected_centroid_y).abs() < 1e-9);
+    }
+
+    #[test]
+    fn test_k_means_predict_empty_matrix() {
+        let (x, k) = create_test_data();
+        let max_iter = 10;
+        let tol = 1e-6;
+        let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
+
+        // Create a 0x0 matrix. This is allowed by Matrix constructor.
+        let empty_x = FloatMatrix::from_rows_vec(vec![], 0, 0);
+        let predicted_labels = kmeans_model.predict(&empty_x);
+        assert!(predicted_labels.is_empty());
+    }
+
+    #[test]
+    fn test_k_means_predict_single_sample() {
+        let (x, k) = create_test_data();
+        let max_iter = 10;
+        let tol = 1e-6;
+        let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
+
+        let single_sample = FloatMatrix::from_rows_vec(vec![1.1, 1.2], 1, 2);
+        let predicted_label = kmeans_model.predict(&single_sample);
+        assert_eq!(predicted_label.len(), 1);
+        assert!(predicted_label[0] < k);
+    }
+}