diff --git a/src/compute/models/logreg.rs b/src/compute/models/logreg.rs new file mode 100644 index 0000000..1bdaa86 --- /dev/null +++ b/src/compute/models/logreg.rs @@ -0,0 +1,36 @@ +use crate::matrix::{Matrix, SeriesOps}; +use crate::compute::activations::sigmoid; + +pub struct LogReg { + w: Matrix, + b: f64, +} + +impl LogReg { + pub fn new(n_features: usize) -> Self { + Self { + w: Matrix::zeros(n_features, 1), + b: 0.0, + } + } + + pub fn predict_proba(&self, x: &Matrix) -> Matrix { + sigmoid(&(x.dot(&self.w) + self.b)) // σ(Xw + b) + } + + pub fn fit(&mut self, x: &Matrix, y: &Matrix, lr: f64, epochs: usize) { + let m = x.rows() as f64; + for _ in 0..epochs { + let p = self.predict_proba(x); // shape (m,1) + let err = &p - y; // derivative of BCE wrt pre-sigmoid + let grad_w = x.transpose().dot(&err) / m; + let grad_b = err.sum_vertical().iter().sum::() / m; + self.w = &self.w - &(grad_w * lr); + self.b -= lr * grad_b; + } + } + + pub fn predict(&self, x: &Matrix) -> Matrix { + self.predict_proba(x).map(|p| if p >= 0.5 { 1.0 } else { 0.0 }) + } +}