mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-10-04 13:19:25 +00:00
applied formatting
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
|
||||
pub mod models;
|
||||
|
||||
pub mod stats;
|
@@ -349,7 +349,8 @@ mod tests {
|
||||
assert!(
|
||||
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
|
||||
"Tanh forward output mismatch at ({}, {})",
|
||||
i, j
|
||||
i,
|
||||
j
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -366,7 +367,8 @@ mod tests {
|
||||
assert!(
|
||||
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
|
||||
"ReLU derivative output mismatch at ({}, {})",
|
||||
i, j
|
||||
i,
|
||||
j
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -383,7 +385,8 @@ mod tests {
|
||||
assert!(
|
||||
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
|
||||
"Tanh derivative output mismatch at ({}, {})",
|
||||
i, j
|
||||
i,
|
||||
j
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -401,7 +404,10 @@ mod tests {
|
||||
assert_eq!(matrix.cols(), cols);
|
||||
|
||||
for val in matrix.data() {
|
||||
assert!(*val >= -limit && *val <= limit, "Xavier initialized value out of range");
|
||||
assert!(
|
||||
*val >= -limit && *val <= limit,
|
||||
"Xavier initialized value out of range"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -417,7 +423,10 @@ mod tests {
|
||||
assert_eq!(matrix.cols(), cols);
|
||||
|
||||
for val in matrix.data() {
|
||||
assert!(*val >= -limit && *val <= limit, "He initialized value out of range");
|
||||
assert!(
|
||||
*val >= -limit && *val <= limit,
|
||||
"He initialized value out of range"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -436,7 +445,8 @@ mod tests {
|
||||
assert!(
|
||||
(output_gradient[(i, j)] - expected_gradient[(i, j)]).abs() < 1e-9,
|
||||
"BCE gradient output mismatch at ({}, {})",
|
||||
i, j
|
||||
i,
|
||||
j
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -462,16 +472,22 @@ mod tests {
|
||||
|
||||
let before_preds = model.predict(&x);
|
||||
// BCE loss calculation for testing
|
||||
let before_loss = -1.0 / (y.rows() as f64) * before_preds.zip(&y, |yh, yv| {
|
||||
yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln()
|
||||
}).data().iter().sum::<f64>();
|
||||
let before_loss = -1.0 / (y.rows() as f64)
|
||||
* before_preds
|
||||
.zip(&y, |yh, yv| yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln())
|
||||
.data()
|
||||
.iter()
|
||||
.sum::<f64>();
|
||||
|
||||
model.train(&x, &y);
|
||||
|
||||
let after_preds = model.predict(&x);
|
||||
let after_loss = -1.0 / (y.rows() as f64) * after_preds.zip(&y, |yh, yv| {
|
||||
yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln()
|
||||
}).data().iter().sum::<f64>();
|
||||
let after_loss = -1.0 / (y.rows() as f64)
|
||||
* after_preds
|
||||
.zip(&y, |yh, yv| yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln())
|
||||
.data()
|
||||
.iter()
|
||||
.sum::<f64>();
|
||||
|
||||
assert!(
|
||||
after_loss < before_loss,
|
||||
|
@@ -1,7 +1,7 @@
|
||||
pub mod activations;
|
||||
pub mod dense_nn;
|
||||
pub mod gaussian_nb;
|
||||
pub mod k_means;
|
||||
pub mod linreg;
|
||||
pub mod logreg;
|
||||
pub mod dense_nn;
|
||||
pub mod k_means;
|
||||
pub mod pca;
|
||||
pub mod gaussian_nb;
|
||||
pub mod activations;
|
||||
|
@@ -1,6 +1,6 @@
|
||||
use crate::matrix::{Axis, Matrix, SeriesOps};
|
||||
use crate::compute::stats::descriptive::mean_vertical;
|
||||
use crate::compute::stats::correlation::covariance_matrix;
|
||||
use crate::compute::stats::descriptive::mean_vertical;
|
||||
use crate::matrix::{Axis, Matrix, SeriesOps};
|
||||
|
||||
/// Returns the `n_components` principal axes (rows) and the centred data's mean.
|
||||
pub struct PCA {
|
||||
@@ -24,10 +24,7 @@ impl PCA {
|
||||
}
|
||||
}
|
||||
|
||||
PCA {
|
||||
components,
|
||||
mean,
|
||||
}
|
||||
PCA { components, mean }
|
||||
}
|
||||
|
||||
/// Project new data on the learned axes.
|
||||
|
@@ -1,7 +1,7 @@
|
||||
pub mod correlation;
|
||||
pub mod descriptive;
|
||||
pub mod distributions;
|
||||
pub mod correlation;
|
||||
|
||||
pub use correlation::*;
|
||||
pub use descriptive::*;
|
||||
pub use distributions::*;
|
||||
pub use correlation::*;
|
@@ -1,7 +1,7 @@
|
||||
pub mod boolops;
|
||||
pub mod mat;
|
||||
pub mod seriesops;
|
||||
pub mod boolops;
|
||||
|
||||
pub use boolops::*;
|
||||
pub use mat::*;
|
||||
pub use seriesops::*;
|
||||
pub use boolops::*;
|
Reference in New Issue
Block a user