From 1c8fcc0bad716aa1399a98d9d32ffde4f017db30 Mon Sep 17 00:00:00 2001
From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com>
Date: Sun, 6 Jul 2025 19:17:03 +0100
Subject: [PATCH] Refactor LogReg implementation for improved readability by
 adjusting formatting and organizing imports

---
 src/compute/models/logreg.rs | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/compute/models/logreg.rs b/src/compute/models/logreg.rs
index 473d5e9..8821ee4 100644
--- a/src/compute/models/logreg.rs
+++ b/src/compute/models/logreg.rs
@@ -1,5 +1,5 @@
-use crate::matrix::{Matrix, SeriesOps};
 use crate::compute::activations::sigmoid;
+use crate::matrix::{Matrix, SeriesOps};
 
 pub struct LogReg {
     w: Matrix,
     b: f64,
 }
@@ -15,14 +15,14 @@ impl LogReg {
     }
 
     pub fn predict_proba(&self, x: &Matrix) -> Matrix {
-        sigmoid(&(x.dot(&self.w) + self.b))   // σ(Xw + b)
+        sigmoid(&(x.dot(&self.w) + self.b)) // σ(Xw + b)
     }
 
     pub fn fit(&mut self, x: &Matrix, y: &Matrix, lr: f64, epochs: usize) {
         let m = x.rows() as f64;
         for _ in 0..epochs {
-            let p = self.predict_proba(x);   // shape (m,1)
-            let err = &p - y;                // derivative of BCE wrt pre-sigmoid
+            let p = self.predict_proba(x); // shape (m,1)
+            let err = &p - y; // derivative of BCE wrt pre-sigmoid
             let grad_w = x.transpose().dot(&err) / m;
             let grad_b = err.sum_vertical().iter().sum::<f64>() / m;
             self.w = &self.w - &(grad_w * lr);
@@ -31,14 +31,14 @@ impl LogReg {
     }
 
     pub fn predict(&self, x: &Matrix) -> Matrix {
-        self.predict_proba(x).map(|p| if p >= 0.5 { 1.0 } else { 0.0 })
+        self.predict_proba(x)
+            .map(|p| if p >= 0.5 { 1.0 } else { 0.0 })
     }
 }
-
+#[cfg(test)]
 mod tests {
-    use super::LogReg;
-    use crate::matrix::Matrix;
+    use super::*;
 
     #[test]
     fn test_logreg_fit_predict() {
@@ -52,4 +52,4 @@ mod tests {
         assert_eq!(preds[(2, 0)], 1.0);
         assert_eq!(preds[(3, 0)], 1.0);
     }
-}
\ No newline at end of file
+}