diff --git a/examples/logistic_regression.rs b/examples/logistic_regression.rs index 1015bef..652f811 100644 --- a/examples/logistic_regression.rs +++ b/examples/logistic_regression.rs @@ -71,3 +71,31 @@ fn purchase_prediction_example() { let p = model.predict_proba(&new_visit); println!("Prob of purchase for 4min/4pages: {:.2}", p[(0, 0)]); } + +#[test] +fn test_student_passing_example() { + let hours = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let passed = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]; + let x = Matrix::from_vec(hours.clone(), hours.len(), 1); + let y = Matrix::from_vec(passed.clone(), passed.len(), 1); + let mut model = LogReg::new(1); + model.fit(&x, &y, 0.1, 10000); + let preds = model.predict(&x); + for i in 0..y.rows() { + assert_eq!(preds[(i, 0)], passed[i]); + } +} + +#[test] +fn test_purchase_prediction_example() { + let raw_x = vec![1.0, 2.0, 3.0, 1.0, 2.0, 4.0, 5.0, 5.0, 3.5, 2.0, 6.0, 6.0]; + let bought = vec![0.0, 0.0, 0.0, 1.0, 0.0, 1.0]; + let x = Matrix::from_rows_vec(raw_x, 6, 2); + let y = Matrix::from_vec(bought.clone(), bought.len(), 1); + let mut model = LogReg::new(2); + model.fit(&x, &y, 0.05, 20000); + let preds = model.predict(&x); + for i in 0..y.rows() { + assert_eq!(preds[(i, 0)], bought[i]); + } +} \ No newline at end of file