mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Refactor LogReg implementation for improved readability by adjusting formatting and organizing imports
This commit is contained in:
parent
2ca496cfd1
commit
1c8fcc0bad
@ -1,5 +1,5 @@
|
|||||||
use crate::matrix::{Matrix, SeriesOps};
|
|
||||||
use crate::compute::activations::sigmoid;
|
use crate::compute::activations::sigmoid;
|
||||||
|
use crate::matrix::{Matrix, SeriesOps};
|
||||||
|
|
||||||
pub struct LogReg {
|
pub struct LogReg {
|
||||||
w: Matrix<f64>,
|
w: Matrix<f64>,
|
||||||
@ -15,14 +15,14 @@ impl LogReg {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn predict_proba(&self, x: &Matrix<f64>) -> Matrix<f64> {
|
pub fn predict_proba(&self, x: &Matrix<f64>) -> Matrix<f64> {
|
||||||
sigmoid(&(x.dot(&self.w) + self.b)) // σ(Xw + b)
|
sigmoid(&(x.dot(&self.w) + self.b)) // σ(Xw + b)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fit(&mut self, x: &Matrix<f64>, y: &Matrix<f64>, lr: f64, epochs: usize) {
|
pub fn fit(&mut self, x: &Matrix<f64>, y: &Matrix<f64>, lr: f64, epochs: usize) {
|
||||||
let m = x.rows() as f64;
|
let m = x.rows() as f64;
|
||||||
for _ in 0..epochs {
|
for _ in 0..epochs {
|
||||||
let p = self.predict_proba(x); // shape (m,1)
|
let p = self.predict_proba(x); // shape (m,1)
|
||||||
let err = &p - y; // derivative of BCE wrt pre-sigmoid
|
let err = &p - y; // derivative of BCE wrt pre-sigmoid
|
||||||
let grad_w = x.transpose().dot(&err) / m;
|
let grad_w = x.transpose().dot(&err) / m;
|
||||||
let grad_b = err.sum_vertical().iter().sum::<f64>() / m;
|
let grad_b = err.sum_vertical().iter().sum::<f64>() / m;
|
||||||
self.w = &self.w - &(grad_w * lr);
|
self.w = &self.w - &(grad_w * lr);
|
||||||
@ -31,14 +31,14 @@ impl LogReg {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
|
pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
|
||||||
self.predict_proba(x).map(|p| if p >= 0.5 { 1.0 } else { 0.0 })
|
self.predict_proba(x)
|
||||||
|
.map(|p| if p >= 0.5 { 1.0 } else { 0.0 })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::LogReg;
|
use super::*;
|
||||||
use crate::matrix::Matrix;
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_logreg_fit_predict() {
|
fn test_logreg_fit_predict() {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user