From 4ddacdfd21b6fe80d090491a10aed1a96e0b7383 Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Sun, 6 Jul 2025 18:52:15 +0100 Subject: [PATCH] Add unit tests for linear regression fit and predict methods --- src/compute/models/linreg.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/compute/models/linreg.rs b/src/compute/models/linreg.rs index 74e5e44..5add9f9 100644 --- a/src/compute/models/linreg.rs +++ b/src/compute/models/linreg.rs @@ -33,3 +33,22 @@ impl LinReg { } } } + +mod tests { + + use super::LinReg; + use crate::matrix::{Matrix}; + + #[test] + fn test_linreg_fit_predict() { + let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1); + let y = Matrix::from_vec(vec![2.0, 3.0, 4.0, 5.0], 4, 1); + let mut model = LinReg::new(1); + model.fit(&x, &y, 0.01, 10000); + let preds = model.predict(&x); + assert!((preds[(0, 0)] - 2.0).abs() < 1e-2); + assert!((preds[(1, 0)] - 3.0).abs() < 1e-2); + assert!((preds[(2, 0)] - 4.0).abs() < 1e-2); + assert!((preds[(3, 0)] - 5.0).abs() < 1e-2); + } +}