mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Add unit tests for logistic regression fit and predict methods
This commit is contained in:
parent
4ddacdfd21
commit
54a266b630
@ -34,3 +34,22 @@ impl LogReg {
|
|||||||
self.predict_proba(x).map(|p| if p >= 0.5 { 1.0 } else { 0.0 })
|
self.predict_proba(x).map(|p| if p >= 0.5 { 1.0 } else { 0.0 })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
mod tests {
|
||||||
|
use super::LogReg;
|
||||||
|
use crate::matrix::Matrix;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_logreg_fit_predict() {
|
||||||
|
let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
|
||||||
|
let y = Matrix::from_vec(vec![0.0, 0.0, 1.0, 1.0], 4, 1);
|
||||||
|
let mut model = LogReg::new(1);
|
||||||
|
model.fit(&x, &y, 0.01, 10000);
|
||||||
|
let preds = model.predict(&x);
|
||||||
|
assert_eq!(preds[(0, 0)], 0.0);
|
||||||
|
assert_eq!(preds[(1, 0)], 0.0);
|
||||||
|
assert_eq!(preds[(2, 0)], 1.0);
|
||||||
|
assert_eq!(preds[(3, 0)], 1.0);
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user