From d0f9e80dfcc59afa0ecfa193b53d316d5f12a303 Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Sat, 26 Jul 2025 18:38:27 +0100 Subject: [PATCH] add test as examples --- examples/linear_regression.rs | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/examples/linear_regression.rs b/examples/linear_regression.rs index 0ad4afe..033d96a 100644 --- a/examples/linear_regression.rs +++ b/examples/linear_regression.rs @@ -81,3 +81,38 @@ fn example_two_features() { pred[(0, 0)] ); } + +#[test] +fn test_linear_regression_one_feature() { + let sizes = vec![50.0, 60.0, 70.0, 80.0, 90.0, 100.0]; + let prices = vec![150.0, 180.0, 210.0, 240.0, 270.0, 300.0]; + let scaled: Vec = sizes.iter().map(|s| s / 100.0).collect(); + let x = Matrix::from_vec(scaled, sizes.len(), 1); + let y = Matrix::from_vec(prices.clone(), prices.len(), 1); + let mut model = LinReg::new(1); + model.fit(&x, &y, 0.1, 2000); + let preds = model.predict(&x); + for i in 0..y.rows() { + assert!((preds[(i, 0)] - prices[i]).abs() < 1.0); + } +} + +#[test] +fn test_linear_regression_two_features() { + let raw_x = vec![ + 50.0, 2.0, 70.0, 2.0, 90.0, 3.0, 110.0, 3.0, 130.0, 4.0, 150.0, 4.0, + ]; + let prices = vec![170.0, 210.0, 270.0, 310.0, 370.0, 410.0]; + let scaled_x: Vec = raw_x + .chunks(2) + .flat_map(|pair| vec![pair[0] / 100.0, pair[1]]) + .collect(); + let x = Matrix::from_rows_vec(scaled_x, 6, 2); + let y = Matrix::from_vec(prices.clone(), prices.len(), 1); + let mut model = LinReg::new(2); + model.fit(&x, &y, 0.01, 50000); + let preds = model.predict(&x); + for i in 0..y.rows() { + assert!((preds[(i, 0)] - prices[i]).abs() < -1.0); + } +}