applied formatting

This commit is contained in:
Palash Tyagi 2025-07-12 00:56:09 +01:00
parent 9b08eaeb35
commit de18d8e010
9 changed files with 45 additions and 33 deletions

View File

@ -1,4 +1,3 @@
pub mod models; pub mod models;
pub mod stats; pub mod stats;

View File

@ -349,7 +349,8 @@ mod tests {
assert!( assert!(
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9, (output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
"Tanh forward output mismatch at ({}, {})", "Tanh forward output mismatch at ({}, {})",
i, j i,
j
); );
} }
} }
@ -366,7 +367,8 @@ mod tests {
assert!( assert!(
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9, (output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
"ReLU derivative output mismatch at ({}, {})", "ReLU derivative output mismatch at ({}, {})",
i, j i,
j
); );
} }
} }
@ -383,7 +385,8 @@ mod tests {
assert!( assert!(
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9, (output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
"Tanh derivative output mismatch at ({}, {})", "Tanh derivative output mismatch at ({}, {})",
i, j i,
j
); );
} }
} }
@ -401,7 +404,10 @@ mod tests {
assert_eq!(matrix.cols(), cols); assert_eq!(matrix.cols(), cols);
for val in matrix.data() { for val in matrix.data() {
assert!(*val >= -limit && *val <= limit, "Xavier initialized value out of range"); assert!(
*val >= -limit && *val <= limit,
"Xavier initialized value out of range"
);
} }
} }
@ -417,7 +423,10 @@ mod tests {
assert_eq!(matrix.cols(), cols); assert_eq!(matrix.cols(), cols);
for val in matrix.data() { for val in matrix.data() {
assert!(*val >= -limit && *val <= limit, "He initialized value out of range"); assert!(
*val >= -limit && *val <= limit,
"He initialized value out of range"
);
} }
} }
@ -436,7 +445,8 @@ mod tests {
assert!( assert!(
(output_gradient[(i, j)] - expected_gradient[(i, j)]).abs() < 1e-9, (output_gradient[(i, j)] - expected_gradient[(i, j)]).abs() < 1e-9,
"BCE gradient output mismatch at ({}, {})", "BCE gradient output mismatch at ({}, {})",
i, j i,
j
); );
} }
} }
@ -462,16 +472,22 @@ mod tests {
let before_preds = model.predict(&x); let before_preds = model.predict(&x);
// BCE loss calculation for testing // BCE loss calculation for testing
let before_loss = -1.0 / (y.rows() as f64) * before_preds.zip(&y, |yh, yv| { let before_loss = -1.0 / (y.rows() as f64)
yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln() * before_preds
}).data().iter().sum::<f64>(); .zip(&y, |yh, yv| yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln())
.data()
.iter()
.sum::<f64>();
model.train(&x, &y); model.train(&x, &y);
let after_preds = model.predict(&x); let after_preds = model.predict(&x);
let after_loss = -1.0 / (y.rows() as f64) * after_preds.zip(&y, |yh, yv| { let after_loss = -1.0 / (y.rows() as f64)
yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln() * after_preds
}).data().iter().sum::<f64>(); .zip(&y, |yh, yv| yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln())
.data()
.iter()
.sum::<f64>();
assert!( assert!(
after_loss < before_loss, after_loss < before_loss,

View File

@ -1,7 +1,7 @@
pub mod activations;
pub mod dense_nn;
pub mod gaussian_nb;
pub mod k_means;
pub mod linreg; pub mod linreg;
pub mod logreg; pub mod logreg;
pub mod dense_nn;
pub mod k_means;
pub mod pca; pub mod pca;
pub mod gaussian_nb;
pub mod activations;

View File

@ -1,6 +1,6 @@
use crate::matrix::{Axis, Matrix, SeriesOps};
use crate::compute::stats::descriptive::mean_vertical;
use crate::compute::stats::correlation::covariance_matrix; use crate::compute::stats::correlation::covariance_matrix;
use crate::compute::stats::descriptive::mean_vertical;
use crate::matrix::{Axis, Matrix, SeriesOps};
/// Returns the `n_components` principal axes (rows) and the centred data's mean. /// Returns the `n_components` principal axes (rows) and the centred data's mean.
pub struct PCA { pub struct PCA {
@ -24,10 +24,7 @@ impl PCA {
} }
} }
PCA { PCA { components, mean }
components,
mean,
}
} }
/// Project new data on the learned axes. /// Project new data on the learned axes.

View File

@ -94,7 +94,7 @@ pub fn covariance_matrix(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
} }
Axis::Row => { Axis::Row => {
let mean_matrix = mean_horizontal(x); // n_samples x 1 let mean_matrix = mean_horizontal(x); // n_samples x 1
// Manually create a matrix by broadcasting the column vector across columns // Manually create a matrix by broadcasting the column vector across columns
let mut broadcasted_mean = Matrix::zeros(n_samples, n_features); let mut broadcasted_mean = Matrix::zeros(n_samples, n_features);
for r in 0..n_samples { for r in 0..n_samples {
let mean_val = mean_matrix.get(r, 0); let mean_val = mean_matrix.get(r, 0);

View File

@ -1,7 +1,7 @@
pub mod correlation;
pub mod descriptive; pub mod descriptive;
pub mod distributions; pub mod distributions;
pub mod correlation;
pub use correlation::*;
pub use descriptive::*; pub use descriptive::*;
pub use distributions::*; pub use distributions::*;
pub use correlation::*;

View File

@ -171,7 +171,7 @@ mod tests {
#[test] #[test]
fn test_bool_ops_count_overall() { fn test_bool_ops_count_overall() {
let matrix = create_bool_test_matrix(); // Data: [T, F, T, F, T, F, T, F, F] let matrix = create_bool_test_matrix(); // Data: [T, F, T, F, T, F, T, F, F]
// Count of true values: 4 // Count of true values: 4
assert_eq!(matrix.count(), 4); assert_eq!(matrix.count(), 4);
let matrix_all_false = BoolMatrix::from_vec(vec![false; 5], 5, 1); // 5x1 let matrix_all_false = BoolMatrix::from_vec(vec![false; 5], 5, 1); // 5x1
@ -211,7 +211,7 @@ mod tests {
#[test] #[test]
fn test_bool_ops_1xn_matrix() { fn test_bool_ops_1xn_matrix() {
let matrix = BoolMatrix::from_vec(vec![true, false, false, true], 1, 4); // 1 row, 4 cols let matrix = BoolMatrix::from_vec(vec![true, false, false, true], 1, 4); // 1 row, 4 cols
// Data: [T, F, F, T] // Data: [T, F, F, T]
assert_eq!(matrix.any_vertical(), vec![true, false, false, true]); assert_eq!(matrix.any_vertical(), vec![true, false, false, true]);
assert_eq!(matrix.all_vertical(), vec![true, false, false, true]); assert_eq!(matrix.all_vertical(), vec![true, false, false, true]);
@ -229,7 +229,7 @@ mod tests {
#[test] #[test]
fn test_bool_ops_nx1_matrix() { fn test_bool_ops_nx1_matrix() {
let matrix = BoolMatrix::from_vec(vec![true, false, false, true], 4, 1); // 4 rows, 1 col let matrix = BoolMatrix::from_vec(vec![true, false, false, true], 4, 1); // 4 rows, 1 col
// Data: [T, F, F, T] // Data: [T, F, F, T]
assert_eq!(matrix.any_vertical(), vec![true]); // T|F|F|T = T assert_eq!(matrix.any_vertical(), vec![true]); // T|F|F|T = T
assert_eq!(matrix.all_vertical(), vec![false]); // T&F&F&T = F assert_eq!(matrix.all_vertical(), vec![false]); // T&F&F&T = F

View File

@ -1,7 +1,7 @@
pub mod boolops;
pub mod mat; pub mod mat;
pub mod seriesops; pub mod seriesops;
pub mod boolops;
pub use boolops::*;
pub use mat::*; pub use mat::*;
pub use seriesops::*; pub use seriesops::*;
pub use boolops::*;