fixed incorrectly committed file

This commit is contained in:
Palash Tyagi 2025-07-06 21:16:57 +01:00
parent 96f434bf94
commit 4648800a09

View File

@ -1,105 +1,113 @@
use crate::matrix::Matrix; use crate::matrix::Matrix;
use std::collections::HashMap; use rand::seq::SliceRandom;
pub struct GaussianNB { pub struct KMeans {
classes: Vec<f64>, // distinct labels pub centroids: Matrix<f64>, // (k, n_features)
priors: Vec<f64>, // P(class)
means: Vec<Matrix<f64>>,
variances: Vec<Matrix<f64>>,
eps: f64, // var-smoothing
} }
impl GaussianNB { impl KMeans {
pub fn new(var_smoothing: f64) -> Self { /// Fit with k clusters.
Self { pub fn fit(x: &Matrix<f64>, k: usize, max_iter: usize, tol: f64) -> (Self, Vec<usize>) {
classes: vec![],
priors: vec![],
means: vec![],
variances: vec![],
eps: var_smoothing,
}
}
pub fn fit(&mut self, x: &Matrix<f64>, y: &Matrix<f64>) {
let m = x.rows(); let m = x.rows();
let n = x.cols(); let n = x.cols();
assert_eq!(y.rows(), m); assert!(k <= m, "k must be ≤ number of samples");
assert_eq!(y.cols(), 1);
// ----- group samples by label ----- // ----- initialise centroids: pick k distinct rows at random -----
let mut groups: HashMap<i64, Vec<usize>> = HashMap::new(); let mut rng = rand::rng();
for i in 0..m { let mut indices: Vec<usize> = (0..m).collect();
groups.entry(y[(i, 0)] as i64).or_default().push(i); indices.shuffle(&mut rng);
let mut centroids = Matrix::zeros(k, n);
for (c, &i) in indices[..k].iter().enumerate() {
for j in 0..n {
centroids[(c, j)] = x[(i, j)];
}
} }
self.classes = groups.keys().cloned().map(|v| v as f64).collect::<Vec<_>>(); let mut labels = vec![0usize; m];
self.classes.sort_by(|a, b| a.partial_cmp(b).unwrap()); for _ in 0..max_iter {
// ----- assignment step -----
self.priors.clear(); let mut changed = false;
self.means.clear(); for i in 0..m {
self.variances.clear(); let mut best = 0usize;
let mut best_dist = f64::MAX;
for &c in &self.classes { for c in 0..k {
let idx = &groups[&(c as i64)]; let mut dist = 0.0;
let count = idx.len(); for j in 0..n {
self.priors.push(count as f64 / m as f64); let d = x[(i, j)] - centroids[(c, j)];
dist += d * d;
let mut mean = Matrix::zeros(1, n); }
let mut var = Matrix::zeros(1, n); if dist < best_dist {
best_dist = dist;
// mean best = c;
for &i in idx { }
for j in 0..n { }
mean[(0, j)] += x[(i, j)]; if labels[i] != best {
labels[i] = best;
changed = true;
} }
} }
for j in 0..n {
mean[(0, j)] /= count as f64;
}
// variance // ----- update step -----
for &i in idx { let mut counts = vec![0usize; k];
let mut centroids = Matrix::zeros(k, n);
for i in 0..m {
let c = labels[i];
counts[c] += 1;
for j in 0..n { for j in 0..n {
let d = x[(i, j)] - mean[(0, j)]; centroids[(c, j)] += x[(i, j)];
var[(0, j)] += d * d;
} }
} }
for j in 0..n {
var[(0, j)] = var[(0, j)] / count as f64 + self.eps;
}
self.means.push(mean);
self.variances.push(var);
}
}
/// Return class labels (shape m×1) for samples in X.
pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
let m = x.rows();
let k = self.classes.len();
let n = x.cols();
let mut preds = Matrix::zeros(m, 1);
let ln_2pi = (2.0 * std::f64::consts::PI).ln();
for i in 0..m {
let mut best_class = 0usize;
let mut best_log_prob = f64::NEG_INFINITY;
for c in 0..k { for c in 0..k {
// log P(y=c) + Σ log N(x_j | μ, σ²) if counts[c] > 0 {
let mut log_prob = self.priors[c].ln(); for j in 0..n {
for j in 0..n { centroids[(c, j)] /= counts[c] as f64;
let mean = self.means[c][(0, j)]; }
let var = self.variances[c][(0, j)]; }
let diff = x[(i, j)] - mean; }
log_prob += -0.5 * (diff * diff / var + var.ln() + ln_2pi);
} // ----- convergence test -----
if log_prob > best_log_prob { if !changed {
best_log_prob = log_prob; break; // assignments stable
best_class = c; }
if tol > 0.0 {
// optional centroid-shift tolerance
let mut shift: f64 = 0.0;
for c in 0..k {
for j in 0..n {
let d = centroids[(c, j)] - centroids[(c, j)]; // previous stored?
shift += d * d;
}
}
if shift.sqrt() < tol {
break;
} }
} }
preds[(i, 0)] = self.classes[best_class];
} }
preds (Self { centroids }, labels)
}
/// Predict nearest centroid for each sample.
pub fn predict(&self, x: &Matrix<f64>) -> Vec<usize> {
let m = x.rows();
let k = self.centroids.rows();
let n = x.cols();
let mut labels = vec![0usize; m];
for i in 0..m {
let mut best = 0usize;
let mut best_dist = f64::MAX;
for c in 0..k {
let mut dist = 0.0;
for j in 0..n {
let d = x[(i, j)] - self.centroids[(c, j)];
dist += d * d;
}
if dist < best_dist {
best_dist = dist;
best = c;
}
}
labels[i] = best;
}
labels
} }
} }