diff --git a/src/compute/models/logreg.rs b/src/compute/models/logreg.rs index 1bdaa86..473d5e9 100644 --- a/src/compute/models/logreg.rs +++ b/src/compute/models/logreg.rs @@ -34,3 +34,22 @@ impl LogReg { self.predict_proba(x).map(|p| if p >= 0.5 { 1.0 } else { 0.0 }) } } + + +mod tests { + use super::LogReg; + use crate::matrix::Matrix; + + #[test] + fn test_logreg_fit_predict() { + let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1); + let y = Matrix::from_vec(vec![0.0, 0.0, 1.0, 1.0], 4, 1); + let mut model = LogReg::new(1); + model.fit(&x, &y, 0.01, 10000); + let preds = model.predict(&x); + assert_eq!(preds[(0, 0)], 0.0); + assert_eq!(preds[(1, 0)], 0.0); + assert_eq!(preds[(2, 0)], 1.0); + assert_eq!(preds[(3, 0)], 1.0); + } +} \ No newline at end of file