diff --git a/src/compute/models/linreg.rs b/src/compute/models/linreg.rs new file mode 100644 index 0000000..74e5e44 --- /dev/null +++ b/src/compute/models/linreg.rs @@ -0,0 +1,35 @@ +use crate::matrix::{Matrix, SeriesOps}; + +pub struct LinReg { + w: Matrix, // shape (n_features, 1) + b: f64, +} + +impl LinReg { + pub fn new(n_features: usize) -> Self { + Self { + w: Matrix::from_vec(vec![0.0; n_features], n_features, 1), + b: 0.0, + } + } + + pub fn predict(&self, x: &Matrix) -> Matrix { + // X.dot(w) + b + x.dot(&self.w) + self.b + } + + pub fn fit(&mut self, x: &Matrix, y: &Matrix, lr: f64, epochs: usize) { + let m = x.rows() as f64; + for _ in 0..epochs { + let y_hat = self.predict(x); + let err = &y_hat - y; // shape (m,1) + + // grads + let grad_w = x.transpose().dot(&err) * (2.0 / m); // (n,1) + let grad_b = (2.0 / m) * err.sum_vertical().iter().sum::(); + // update + self.w = &self.w - &(grad_w * lr); + self.b -= lr * grad_b; + } + } +}