From 4648800a09b81c2e82069235c3e18c55f4afafc1 Mon Sep 17 00:00:00 2001
From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com>
Date: Sun, 6 Jul 2025 21:16:57 +0100
Subject: [PATCH] fixed incorrectly committed file

---
Review notes (this section is not part of the commit message):
* KMeans::fit: the update step no longer shadows `centroids`, so the updated
  centroids are carried into the next iteration and returned to the caller.
* KMeans::fit: the tolerance test now compares the new centroids against the
  previous ones (it used to subtract a centroid from itself, always zero).
* KMeans::fit: an empty cluster keeps its previous centroid instead of
  collapsing to the origin.
* The commit id above and the post-image blob id on the `index` line still
  refer to the original revision of this patch.

 src/compute/models/k_means.rs | 200 ++++++++++++++++++++--------------
 1 file changed, 116 insertions(+), 84 deletions(-)

diff --git a/src/compute/models/k_means.rs b/src/compute/models/k_means.rs
index 716114f..7b7fcf0 100644
--- a/src/compute/models/k_means.rs
+++ b/src/compute/models/k_means.rs
@@ -1,105 +1,137 @@
 use crate::matrix::Matrix;
-use std::collections::HashMap;
+use rand::seq::SliceRandom;
 
-pub struct GaussianNB {
-    classes: Vec<f64>, // distinct labels
-    priors: Vec<f64>, // P(class)
-    means: Vec<Matrix<f64>>,
-    variances: Vec<Matrix<f64>>,
-    eps: f64, // var-smoothing
+/// K-means clustering model.
+pub struct KMeans {
+    /// Cluster centroids, one centroid per row.
+    pub centroids: Matrix<f64>, // (k, n_features)
 }
 
-impl GaussianNB {
-    pub fn new(var_smoothing: f64) -> Self {
-        Self {
-            classes: vec![],
-            priors: vec![],
-            means: vec![],
-            variances: vec![],
-            eps: var_smoothing,
-        }
-    }
-
-    pub fn fit(&mut self, x: &Matrix<f64>, y: &Matrix<f64>) {
+impl KMeans {
+    /// Fit with k clusters.
+    ///
+    /// Seeds the centroids with `k` distinct, randomly chosen rows of `x`, then
+    /// alternates assignment and centroid-update steps for at most `max_iter`
+    /// iterations. Stops early when no label changes or, if `tol > 0.0`, when the
+    /// centroids move by less than `tol` (Euclidean norm over all coordinates).
+    ///
+    /// Returns the fitted model and the cluster index assigned to each row of `x`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `k` is greater than the number of rows in `x`.
+    pub fn fit(x: &Matrix<f64>, k: usize, max_iter: usize, tol: f64) -> (Self, Vec<usize>) {
         let m = x.rows();
         let n = x.cols();
-        assert_eq!(y.rows(), m);
-        assert_eq!(y.cols(), 1);
+        assert!(k <= m, "k must be ≤ number of samples");
 
-        // ----- group samples by label -----
-        let mut groups: HashMap<i64, Vec<usize>> = HashMap::new();
-        for i in 0..m {
-            groups.entry(y[(i, 0)] as i64).or_default().push(i);
+        // ----- initialise centroids: pick k distinct rows at random -----
+        let mut rng = rand::rng();
+        let mut indices: Vec<usize> = (0..m).collect();
+        indices.shuffle(&mut rng);
+        let mut centroids = Matrix::zeros(k, n);
+        for (c, &i) in indices[..k].iter().enumerate() {
+            for j in 0..n {
+                centroids[(c, j)] = x[(i, j)];
+            }
         }
 
-        self.classes = groups.keys().cloned().map(|v| v as f64).collect::<Vec<f64>>();
-        self.classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
-
-        self.priors.clear();
-        self.means.clear();
-        self.variances.clear();
-
-        for &c in &self.classes {
-            let idx = &groups[&(c as i64)];
-            let count = idx.len();
-            self.priors.push(count as f64 / m as f64);
-
-            let mut mean = Matrix::zeros(1, n);
-            let mut var = Matrix::zeros(1, n);
-
-            // mean
-            for &i in idx {
-                for j in 0..n {
-                    mean[(0, j)] += x[(i, j)];
+        let mut labels = vec![0usize; m];
+        for _ in 0..max_iter {
+            // ----- assignment step -----
+            let mut changed = false;
+            for i in 0..m {
+                let mut best = 0usize;
+                let mut best_dist = f64::MAX;
+                for c in 0..k {
+                    let mut dist = 0.0;
+                    for j in 0..n {
+                        let d = x[(i, j)] - centroids[(c, j)];
+                        dist += d * d;
+                    }
+                    if dist < best_dist {
+                        best_dist = dist;
+                        best = c;
+                    }
+                }
+                if labels[i] != best {
+                    labels[i] = best;
+                    changed = true;
                 }
             }
-            for j in 0..n {
-                mean[(0, j)] /= count as f64;
-            }
 
-            // variance
-            for &i in idx {
+            // ----- update step -----
+            // Accumulate into a fresh matrix so the current centroids stay readable.
+            let mut counts = vec![0usize; k];
+            let mut new_centroids = Matrix::zeros(k, n);
+            for i in 0..m {
+                let c = labels[i];
+                counts[c] += 1;
                 for j in 0..n {
-                    let d = x[(i, j)] - mean[(0, j)];
-                    var[(0, j)] += d * d;
+                    new_centroids[(c, j)] += x[(i, j)];
                 }
             }
-            for j in 0..n {
-                var[(0, j)] = var[(0, j)] / count as f64 + self.eps;
-            }
-
-            self.means.push(mean);
-            self.variances.push(var);
-        }
-    }
-
-    /// Return class labels (shape m×1) for samples in X.
-    pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
-        let m = x.rows();
-        let k = self.classes.len();
-        let n = x.cols();
-        let mut preds = Matrix::zeros(m, 1);
-        let ln_2pi = (2.0 * std::f64::consts::PI).ln();
-
-        for i in 0..m {
-            let mut best_class = 0usize;
-            let mut best_log_prob = f64::NEG_INFINITY;
             for c in 0..k {
-                // log P(y=c) + Σ log N(x_j | μ, σ²)
-                let mut log_prob = self.priors[c].ln();
-                for j in 0..n {
-                    let mean = self.means[c][(0, j)];
-                    let var = self.variances[c][(0, j)];
-                    let diff = x[(i, j)] - mean;
-                    log_prob += -0.5 * (diff * diff / var + var.ln() + ln_2pi);
-                }
-                if log_prob > best_log_prob {
-                    best_log_prob = log_prob;
-                    best_class = c;
+                if counts[c] > 0 {
+                    for j in 0..n {
+                        new_centroids[(c, j)] /= counts[c] as f64;
+                    }
+                } else {
+                    // empty cluster: keep its previous centroid rather than zeroing it
+                    for j in 0..n {
+                        new_centroids[(c, j)] = centroids[(c, j)];
+                    }
+                }
+            }
+            // Commit the update; keep the old centroids for the shift test below.
+            let prev_centroids = std::mem::replace(&mut centroids, new_centroids);
+
+            // ----- convergence test -----
+            if !changed {
+                break; // assignments stable
+            }
+            if tol > 0.0 {
+                // optional centroid-shift tolerance
+                let mut shift: f64 = 0.0;
+                for c in 0..k {
+                    for j in 0..n {
+                        let d = centroids[(c, j)] - prev_centroids[(c, j)];
+                        shift += d * d;
+                    }
+                }
+                if shift.sqrt() < tol {
+                    break;
                 }
             }
-            preds[(i, 0)] = self.classes[best_class];
         }
-        preds
+        (Self { centroids }, labels)
+    }
+
+    /// Predict nearest centroid for each sample.
+    ///
+    /// Returns, for each row of `x`, the index of the centroid with the smallest
+    /// squared Euclidean distance; ties go to the lowest index.
+    pub fn predict(&self, x: &Matrix<f64>) -> Vec<usize> {
+        let m = x.rows();
+        let k = self.centroids.rows();
+        let n = x.cols();
+        let mut labels = vec![0usize; m];
+        for i in 0..m {
+            let mut best = 0usize;
+            let mut best_dist = f64::MAX;
+            for c in 0..k {
+                let mut dist = 0.0;
+                for j in 0..n {
+                    let d = x[(i, j)] - self.centroids[(c, j)];
+                    dist += d * d;
+                }
+                if dist < best_dist {
+                    best_dist = dist;
+                    best = c;
+                }
+            }
+            labels[i] = best;
+        }
+        labels
     }
 }