Refactor LogReg implementation for improved readability by adjusting formatting and organizing imports

This commit is contained in:
Palash Tyagi 2025-07-06 19:17:03 +01:00
parent 2ca496cfd1
commit 1c8fcc0bad

View File

@ -1,5 +1,5 @@
use crate::matrix::{Matrix, SeriesOps};
use crate::compute::activations::sigmoid; use crate::compute::activations::sigmoid;
use crate::matrix::{Matrix, SeriesOps};
pub struct LogReg { pub struct LogReg {
w: Matrix<f64>, w: Matrix<f64>,
@ -15,14 +15,14 @@ impl LogReg {
} }
pub fn predict_proba(&self, x: &Matrix<f64>) -> Matrix<f64> { pub fn predict_proba(&self, x: &Matrix<f64>) -> Matrix<f64> {
sigmoid(&(x.dot(&self.w) + self.b)) // σ(Xw + b) sigmoid(&(x.dot(&self.w) + self.b)) // σ(Xw + b)
} }
pub fn fit(&mut self, x: &Matrix<f64>, y: &Matrix<f64>, lr: f64, epochs: usize) { pub fn fit(&mut self, x: &Matrix<f64>, y: &Matrix<f64>, lr: f64, epochs: usize) {
let m = x.rows() as f64; let m = x.rows() as f64;
for _ in 0..epochs { for _ in 0..epochs {
let p = self.predict_proba(x); // shape (m,1) let p = self.predict_proba(x); // shape (m,1)
let err = &p - y; // derivative of BCE wrt pre-sigmoid let err = &p - y; // derivative of BCE wrt pre-sigmoid
let grad_w = x.transpose().dot(&err) / m; let grad_w = x.transpose().dot(&err) / m;
let grad_b = err.sum_vertical().iter().sum::<f64>() / m; let grad_b = err.sum_vertical().iter().sum::<f64>() / m;
self.w = &self.w - &(grad_w * lr); self.w = &self.w - &(grad_w * lr);
@ -31,14 +31,14 @@ impl LogReg {
} }
pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> { pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
self.predict_proba(x).map(|p| if p >= 0.5 { 1.0 } else { 0.0 }) self.predict_proba(x)
.map(|p| if p >= 0.5 { 1.0 } else { 0.0 })
} }
} }
#[cfg(test)]
mod tests { mod tests {
use super::LogReg; use super::*;
use crate::matrix::Matrix;
#[test] #[test]
fn test_logreg_fit_predict() { fn test_logreg_fit_predict() {