mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 08:30:01 +00:00
Compare commits
6 Commits
b7480b20d4
...
7b0d34384a
Author | SHA1 | Date | |
---|---|---|---|
![]() |
7b0d34384a | ||
![]() |
9182ab9fca | ||
![]() |
de18d8e010 | ||
![]() |
9b08eaeb35 | ||
![]() |
a3bb509202 | ||
![]() |
10018f7efe |
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
pub mod models;
|
pub mod models;
|
||||||
|
|
||||||
pub mod stats;
|
pub mod stats;
|
||||||
|
@ -264,10 +264,8 @@ mod tests {
|
|||||||
|
|
||||||
for i in 0..before.rows() {
|
for i in 0..before.rows() {
|
||||||
for j in 0..before.cols() {
|
for j in 0..before.cols() {
|
||||||
assert!(
|
// "prediction changed despite 0 epochs"
|
||||||
(before[(i, j)] - after[(i, j)]).abs() < 1e-12,
|
assert!((before[(i, j)] - after[(i, j)]).abs() < 1e-12);
|
||||||
"prediction changed despite 0 epochs"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -330,12 +328,8 @@ mod tests {
|
|||||||
let after_preds = model.predict(&x);
|
let after_preds = model.predict(&x);
|
||||||
let after_loss = mse_loss(&after_preds, &y);
|
let after_loss = mse_loss(&after_preds, &y);
|
||||||
|
|
||||||
assert!(
|
// MSE did not decrease (before: {}, after: {})
|
||||||
after_loss < before_loss,
|
assert!(after_loss < before_loss);
|
||||||
"MSE did not decrease (before: {}, after: {})",
|
|
||||||
before_loss,
|
|
||||||
after_loss
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -346,11 +340,8 @@ mod tests {
|
|||||||
|
|
||||||
for i in 0..input.rows() {
|
for i in 0..input.rows() {
|
||||||
for j in 0..input.cols() {
|
for j in 0..input.cols() {
|
||||||
assert!(
|
// Tanh forward output mismatch at ({}, {})
|
||||||
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
|
assert!((output[(i, j)] - expected[(i, j)]).abs() < 1e-9);
|
||||||
"Tanh forward output mismatch at ({}, {})",
|
|
||||||
i, j
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -363,11 +354,8 @@ mod tests {
|
|||||||
|
|
||||||
for i in 0..input.rows() {
|
for i in 0..input.rows() {
|
||||||
for j in 0..input.cols() {
|
for j in 0..input.cols() {
|
||||||
assert!(
|
// "ReLU derivative output mismatch at ({}, {})"
|
||||||
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
|
assert!((output[(i, j)] - expected[(i, j)]).abs() < 1e-9);
|
||||||
"ReLU derivative output mismatch at ({}, {})",
|
|
||||||
i, j
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -380,11 +368,8 @@ mod tests {
|
|||||||
|
|
||||||
for i in 0..input.rows() {
|
for i in 0..input.rows() {
|
||||||
for j in 0..input.cols() {
|
for j in 0..input.cols() {
|
||||||
assert!(
|
// "Tanh derivative output mismatch at ({}, {})"
|
||||||
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
|
assert!((output[(i, j)] - expected[(i, j)]).abs() < 1e-9);
|
||||||
"Tanh derivative output mismatch at ({}, {})",
|
|
||||||
i, j
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -401,7 +386,8 @@ mod tests {
|
|||||||
assert_eq!(matrix.cols(), cols);
|
assert_eq!(matrix.cols(), cols);
|
||||||
|
|
||||||
for val in matrix.data() {
|
for val in matrix.data() {
|
||||||
assert!(*val >= -limit && *val <= limit, "Xavier initialized value out of range");
|
// Xavier initialized value out of range
|
||||||
|
assert!(*val >= -limit && *val <= limit);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -417,7 +403,8 @@ mod tests {
|
|||||||
assert_eq!(matrix.cols(), cols);
|
assert_eq!(matrix.cols(), cols);
|
||||||
|
|
||||||
for val in matrix.data() {
|
for val in matrix.data() {
|
||||||
assert!(*val >= -limit && *val <= limit, "He initialized value out of range");
|
// He initialized value out of range
|
||||||
|
assert!(*val >= -limit && *val <= limit);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -433,11 +420,8 @@ mod tests {
|
|||||||
|
|
||||||
for i in 0..output_gradient.rows() {
|
for i in 0..output_gradient.rows() {
|
||||||
for j in 0..output_gradient.cols() {
|
for j in 0..output_gradient.cols() {
|
||||||
assert!(
|
// BCE gradient output mismatch at ({}, {})
|
||||||
(output_gradient[(i, j)] - expected_gradient[(i, j)]).abs() < 1e-9,
|
assert!((output_gradient[(i, j)] - expected_gradient[(i, j)]).abs() < 1e-9);
|
||||||
"BCE gradient output mismatch at ({}, {})",
|
|
||||||
i, j
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -462,23 +446,25 @@ mod tests {
|
|||||||
|
|
||||||
let before_preds = model.predict(&x);
|
let before_preds = model.predict(&x);
|
||||||
// BCE loss calculation for testing
|
// BCE loss calculation for testing
|
||||||
let before_loss = -1.0 / (y.rows() as f64) * before_preds.zip(&y, |yh, yv| {
|
let before_loss = -1.0 / (y.rows() as f64)
|
||||||
yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln()
|
* before_preds
|
||||||
}).data().iter().sum::<f64>();
|
.zip(&y, |yh, yv| yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln())
|
||||||
|
.data()
|
||||||
|
.iter()
|
||||||
|
.sum::<f64>();
|
||||||
|
|
||||||
model.train(&x, &y);
|
model.train(&x, &y);
|
||||||
|
|
||||||
let after_preds = model.predict(&x);
|
let after_preds = model.predict(&x);
|
||||||
let after_loss = -1.0 / (y.rows() as f64) * after_preds.zip(&y, |yh, yv| {
|
let after_loss = -1.0 / (y.rows() as f64)
|
||||||
yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln()
|
* after_preds
|
||||||
}).data().iter().sum::<f64>();
|
.zip(&y, |yh, yv| yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln())
|
||||||
|
.data()
|
||||||
|
.iter()
|
||||||
|
.sum::<f64>();
|
||||||
|
|
||||||
assert!(
|
// BCE did not decrease (before: {}, after: {})
|
||||||
after_loss < before_loss,
|
assert!(after_loss < before_loss,);
|
||||||
"BCE did not decrease (before: {}, after: {})",
|
|
||||||
before_loss,
|
|
||||||
after_loss
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -509,21 +495,15 @@ mod tests {
|
|||||||
|
|
||||||
// Verify that weights and biases of both layers have changed,
|
// Verify that weights and biases of both layers have changed,
|
||||||
// implying delta propagation occurred for l > 0
|
// implying delta propagation occurred for l > 0
|
||||||
assert!(
|
|
||||||
model.weights[0] != initial_weights_l0,
|
|
||||||
"Weights of first layer did not change, delta propagation might not have occurred"
|
// Weights of first layer did not change, delta propagation might not have occurred
|
||||||
);
|
assert!(model.weights[0] != initial_weights_l0);
|
||||||
assert!(
|
// Biases of first layer did not change, delta propagation might not have occurred
|
||||||
model.biases[0] != initial_biases_l0,
|
assert!(model.biases[0] != initial_biases_l0);
|
||||||
"Biases of first layer did not change, delta propagation might not have occurred"
|
// Weights of second layer did not change
|
||||||
);
|
assert!(model.weights[1] != initial_weights_l1);
|
||||||
assert!(
|
// Biases of second layer did not change
|
||||||
model.weights[1] != initial_weights_l1,
|
assert!(model.biases[1] != initial_biases_l1);
|
||||||
"Weights of second layer did not change"
|
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
model.biases[1] != initial_biases_l1,
|
|
||||||
"Biases of second layer did not change"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
|
pub mod activations;
|
||||||
|
pub mod dense_nn;
|
||||||
|
pub mod gaussian_nb;
|
||||||
|
pub mod k_means;
|
||||||
pub mod linreg;
|
pub mod linreg;
|
||||||
pub mod logreg;
|
pub mod logreg;
|
||||||
pub mod dense_nn;
|
|
||||||
pub mod k_means;
|
|
||||||
pub mod pca;
|
pub mod pca;
|
||||||
pub mod gaussian_nb;
|
|
||||||
pub mod activations;
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::matrix::{Axis, Matrix, SeriesOps};
|
|
||||||
use crate::compute::stats::descriptive::mean_vertical;
|
|
||||||
use crate::compute::stats::correlation::covariance_matrix;
|
use crate::compute::stats::correlation::covariance_matrix;
|
||||||
|
use crate::compute::stats::descriptive::mean_vertical;
|
||||||
|
use crate::matrix::{Axis, Matrix, SeriesOps};
|
||||||
|
|
||||||
/// Returns the `n_components` principal axes (rows) and the centred data's mean.
|
/// Returns the `n_components` principal axes (rows) and the centred data's mean.
|
||||||
pub struct PCA {
|
pub struct PCA {
|
||||||
@ -24,10 +24,7 @@ impl PCA {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PCA {
|
PCA { components, mean }
|
||||||
components,
|
|
||||||
mean,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Project new data on the learned axes.
|
/// Project new data on the learned axes.
|
||||||
@ -53,7 +50,7 @@ mod tests {
|
|||||||
// 2.0, 2.0
|
// 2.0, 2.0
|
||||||
// 3.0, 3.0
|
// 3.0, 3.0
|
||||||
let data = Matrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0], 3, 2);
|
let data = Matrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0], 3, 2);
|
||||||
let (n_samples, n_features) = data.shape();
|
let (_n_samples, _n_features) = data.shape();
|
||||||
|
|
||||||
let pca = PCA::fit(&data, 1, 0); // n_components = 1, iters is unused
|
let pca = PCA::fit(&data, 1, 0); // n_components = 1, iters is unused
|
||||||
|
|
||||||
@ -90,4 +87,28 @@ mod tests {
|
|||||||
assert!((transformed_data.get(1, 0) - 0.0).abs() < EPSILON);
|
assert!((transformed_data.get(1, 0) - 0.0).abs() < EPSILON);
|
||||||
assert!((transformed_data.get(2, 0) - 2.0).abs() < EPSILON);
|
assert!((transformed_data.get(2, 0) - 2.0).abs() < EPSILON);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pca_fit_break_branch() {
|
||||||
|
// Data with 2 features
|
||||||
|
let data = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
|
||||||
|
let (_n_samples, n_features) = data.shape();
|
||||||
|
|
||||||
|
// Set n_components greater than n_features to trigger the break branch
|
||||||
|
let n_components_large = n_features + 1;
|
||||||
|
let pca = PCA::fit(&data, n_components_large, 0);
|
||||||
|
|
||||||
|
// The components matrix should be initialized with n_components_large rows,
|
||||||
|
// but only the first n_features rows should be copied from the covariance matrix.
|
||||||
|
// The remaining rows should be zeros.
|
||||||
|
assert_eq!(pca.components.rows(), n_components_large);
|
||||||
|
assert_eq!(pca.components.cols(), n_features);
|
||||||
|
|
||||||
|
// Verify that rows beyond n_features are all zeros
|
||||||
|
for i in n_features..n_components_large {
|
||||||
|
for j in 0..n_features {
|
||||||
|
assert!((pca.components.get(i, j) - 0.0).abs() < EPSILON);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -82,16 +82,30 @@ pub fn covariance_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
|
|||||||
/// Calculates the covariance matrix of the input data.
|
/// Calculates the covariance matrix of the input data.
|
||||||
/// Assumes input `x` is (n_samples, n_features).
|
/// Assumes input `x` is (n_samples, n_features).
|
||||||
pub fn covariance_matrix(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
|
pub fn covariance_matrix(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
|
||||||
let (n_samples, _n_features) = x.shape();
|
let (n_samples, n_features) = x.shape();
|
||||||
|
|
||||||
let mean_matrix = match axis {
|
let centered_data = match axis {
|
||||||
Axis::Col => mean_vertical(x), // Mean of each feature (column)
|
Axis::Col => {
|
||||||
Axis::Row => mean_horizontal(x), // Mean of each sample (row)
|
let mean_matrix = mean_vertical(x); // 1 x n_features
|
||||||
|
x.zip(
|
||||||
|
&mean_matrix.broadcast_row_to_target_shape(n_samples, n_features),
|
||||||
|
|val, m| val - m,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Axis::Row => {
|
||||||
|
let mean_matrix = mean_horizontal(x); // n_samples x 1
|
||||||
|
// Manually create a matrix by broadcasting the column vector across columns
|
||||||
|
let mut broadcasted_mean = Matrix::zeros(n_samples, n_features);
|
||||||
|
for r in 0..n_samples {
|
||||||
|
let mean_val = mean_matrix.get(r, 0);
|
||||||
|
for c in 0..n_features {
|
||||||
|
*broadcasted_mean.get_mut(r, c) = *mean_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
x.zip(&broadcasted_mean, |val, m| val - m)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Center the data
|
|
||||||
let centered_data = x.zip(&mean_matrix.broadcast_row_to_target_shape(n_samples, x.cols()), |val, m| val - m);
|
|
||||||
|
|
||||||
// Calculate covariance matrix: (X_centered^T * X_centered) / (n_samples - 1)
|
// Calculate covariance matrix: (X_centered^T * X_centered) / (n_samples - 1)
|
||||||
// If x is (n_samples, n_features), then centered_data is (n_samples, n_features)
|
// If x is (n_samples, n_features), then centered_data is (n_samples, n_features)
|
||||||
// centered_data.transpose() is (n_features, n_samples)
|
// centered_data.transpose() is (n_features, n_samples)
|
||||||
@ -148,13 +162,7 @@ mod tests {
|
|||||||
// Expect 2x2 matrix of all 1.0
|
// Expect 2x2 matrix of all 1.0
|
||||||
for i in 0..2 {
|
for i in 0..2 {
|
||||||
for j in 0..2 {
|
for j in 0..2 {
|
||||||
assert!(
|
assert!((cov_mat.get(i, j) - 1.0).abs() < EPS);
|
||||||
(cov_mat.get(i, j) - 1.0).abs() < EPS,
|
|
||||||
"cov_mat[{},{}] = {}",
|
|
||||||
i,
|
|
||||||
j,
|
|
||||||
cov_mat.get(i, j)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -171,14 +179,58 @@ mod tests {
|
|||||||
// Expect 2x2 matrix of all 0.25
|
// Expect 2x2 matrix of all 0.25
|
||||||
for i in 0..2 {
|
for i in 0..2 {
|
||||||
for j in 0..2 {
|
for j in 0..2 {
|
||||||
assert!(
|
assert!((cov_mat.get(i, j) - 0.25).abs() < EPS);
|
||||||
(cov_mat.get(i, j) - 0.25).abs() < EPS,
|
|
||||||
"cov_mat[{},{}] = {}",
|
|
||||||
i,
|
|
||||||
j,
|
|
||||||
cov_mat.get(i, j)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_covariance_matrix_vertical() {
|
||||||
|
// Test with a simple 2x2 matrix
|
||||||
|
// M =
|
||||||
|
// 1, 2
|
||||||
|
// 3, 4
|
||||||
|
// Expected covariance matrix (vertical, i.e., between columns):
|
||||||
|
// Col1: [1, 3], mean = 2
|
||||||
|
// Col2: [2, 4], mean = 3
|
||||||
|
// Cov(Col1, Col1) = ((1-2)^2 + (3-2)^2) / (2-1) = (1+1)/1 = 2
|
||||||
|
// Cov(Col2, Col2) = ((2-3)^2 + (4-3)^2) / (2-1) = (1+1)/1 = 2
|
||||||
|
// Cov(Col1, Col2) = ((1-2)*(2-3) + (3-2)*(4-3)) / (2-1) = ((-1)*(-1) + (1)*(1))/1 = (1+1)/1 = 2
|
||||||
|
// Cov(Col2, Col1) = 2
|
||||||
|
// Expected:
|
||||||
|
// 2, 2
|
||||||
|
// 2, 2
|
||||||
|
let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||||
|
let cov_mat = covariance_matrix(&m, Axis::Col);
|
||||||
|
|
||||||
|
assert!((cov_mat.get(0, 0) - 2.0).abs() < EPS);
|
||||||
|
assert!((cov_mat.get(0, 1) - 2.0).abs() < EPS);
|
||||||
|
assert!((cov_mat.get(1, 0) - 2.0).abs() < EPS);
|
||||||
|
assert!((cov_mat.get(1, 1) - 2.0).abs() < EPS);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_covariance_matrix_horizontal() {
|
||||||
|
// Test with a simple 2x2 matrix
|
||||||
|
// M =
|
||||||
|
// 1, 2
|
||||||
|
// 3, 4
|
||||||
|
// Expected covariance matrix (horizontal, i.e., between rows):
|
||||||
|
// Row1: [1, 2], mean = 1.5
|
||||||
|
// Row2: [3, 4], mean = 3.5
|
||||||
|
// Cov(Row1, Row1) = ((1-1.5)^2 + (2-1.5)^2) / (2-1) = (0.25+0.25)/1 = 0.5
|
||||||
|
// Cov(Row2, Row2) = ((3-3.5)^2 + (4-3.5)^2) / (2-1) = (0.25+0.25)/1 = 0.5
|
||||||
|
// Cov(Row1, Row2) = ((1-1.5)*(3-3.5) + (2-1.5)*(4-3.5)) / (2-1) = ((-0.5)*(-0.5) + (0.5)*(0.5))/1 = (0.25+0.25)/1 = 0.5
|
||||||
|
// Cov(Row2, Row1) = 0.5
|
||||||
|
// Expected:
|
||||||
|
// 0.5, -0.5
|
||||||
|
// -0.5, 0.5
|
||||||
|
let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||||
|
let cov_mat = covariance_matrix(&m, Axis::Row);
|
||||||
|
|
||||||
|
assert!((cov_mat.get(0, 0) - 0.5).abs() < EPS);
|
||||||
|
assert!((cov_mat.get(0, 1) - (-0.5)).abs() < EPS);
|
||||||
|
assert!((cov_mat.get(1, 0) - (-0.5)).abs() < EPS);
|
||||||
|
assert!((cov_mat.get(1, 1) - 0.5).abs() < EPS);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
|
pub mod correlation;
|
||||||
pub mod descriptive;
|
pub mod descriptive;
|
||||||
pub mod distributions;
|
pub mod distributions;
|
||||||
pub mod correlation;
|
|
||||||
|
|
||||||
|
pub use correlation::*;
|
||||||
pub use descriptive::*;
|
pub use descriptive::*;
|
||||||
pub use distributions::*;
|
pub use distributions::*;
|
||||||
pub use correlation::*;
|
|
@ -4,4 +4,4 @@ pub mod ops;
|
|||||||
pub use base::*;
|
pub use base::*;
|
||||||
|
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
pub use ops::*;
|
pub use ops::*;
|
||||||
|
@ -171,7 +171,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_bool_ops_count_overall() {
|
fn test_bool_ops_count_overall() {
|
||||||
let matrix = create_bool_test_matrix(); // Data: [T, F, T, F, T, F, T, F, F]
|
let matrix = create_bool_test_matrix(); // Data: [T, F, T, F, T, F, T, F, F]
|
||||||
// Count of true values: 4
|
// Count of true values: 4
|
||||||
assert_eq!(matrix.count(), 4);
|
assert_eq!(matrix.count(), 4);
|
||||||
|
|
||||||
let matrix_all_false = BoolMatrix::from_vec(vec![false; 5], 5, 1); // 5x1
|
let matrix_all_false = BoolMatrix::from_vec(vec![false; 5], 5, 1); // 5x1
|
||||||
@ -211,7 +211,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_bool_ops_1xn_matrix() {
|
fn test_bool_ops_1xn_matrix() {
|
||||||
let matrix = BoolMatrix::from_vec(vec![true, false, false, true], 1, 4); // 1 row, 4 cols
|
let matrix = BoolMatrix::from_vec(vec![true, false, false, true], 1, 4); // 1 row, 4 cols
|
||||||
// Data: [T, F, F, T]
|
// Data: [T, F, F, T]
|
||||||
|
|
||||||
assert_eq!(matrix.any_vertical(), vec![true, false, false, true]);
|
assert_eq!(matrix.any_vertical(), vec![true, false, false, true]);
|
||||||
assert_eq!(matrix.all_vertical(), vec![true, false, false, true]);
|
assert_eq!(matrix.all_vertical(), vec![true, false, false, true]);
|
||||||
@ -229,7 +229,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_bool_ops_nx1_matrix() {
|
fn test_bool_ops_nx1_matrix() {
|
||||||
let matrix = BoolMatrix::from_vec(vec![true, false, false, true], 4, 1); // 4 rows, 1 col
|
let matrix = BoolMatrix::from_vec(vec![true, false, false, true], 4, 1); // 4 rows, 1 col
|
||||||
// Data: [T, F, F, T]
|
// Data: [T, F, F, T]
|
||||||
|
|
||||||
assert_eq!(matrix.any_vertical(), vec![true]); // T|F|F|T = T
|
assert_eq!(matrix.any_vertical(), vec![true]); // T|F|F|T = T
|
||||||
assert_eq!(matrix.all_vertical(), vec![false]); // T&F&F&T = F
|
assert_eq!(matrix.all_vertical(), vec![false]); // T&F&F&T = F
|
||||||
|
@ -386,15 +386,31 @@ impl<T: Clone> Matrix<T> {
|
|||||||
|
|
||||||
/// Creates a new matrix by broadcasting a 1-row matrix to a target shape.
|
/// Creates a new matrix by broadcasting a 1-row matrix to a target shape.
|
||||||
/// Panics if `self` is not a 1-row matrix or if `self.cols()` does not match `target_cols`.
|
/// Panics if `self` is not a 1-row matrix or if `self.cols()` does not match `target_cols`.
|
||||||
pub fn broadcast_row_to_target_shape(&self, target_rows: usize, target_cols: usize) -> Matrix<T> {
|
pub fn broadcast_row_to_target_shape(
|
||||||
assert_eq!(self.rows(), 1, "broadcast_row_to_target_shape can only be called on a 1-row matrix.");
|
&self,
|
||||||
assert_eq!(self.cols(), target_cols, "Column count mismatch for broadcasting: source has {} columns, target has {} columns.", self.cols(), target_cols);
|
target_rows: usize,
|
||||||
|
target_cols: usize,
|
||||||
|
) -> Matrix<T> {
|
||||||
|
assert_eq!(
|
||||||
|
self.rows(),
|
||||||
|
1,
|
||||||
|
"broadcast_row_to_target_shape can only be called on a 1-row matrix."
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
self.cols(),
|
||||||
|
target_cols,
|
||||||
|
"Column count mismatch for broadcasting: source has {} columns, target has {} columns.",
|
||||||
|
self.cols(),
|
||||||
|
target_cols
|
||||||
|
);
|
||||||
|
|
||||||
let mut data = Vec::with_capacity(target_rows * target_cols);
|
let mut data = Vec::with_capacity(target_rows * target_cols);
|
||||||
let original_row_data = self.row(0); // Get the single row data
|
let original_row_data = self.row(0); // Get the single row data
|
||||||
|
|
||||||
for _ in 0..target_rows { // Repeat 'target_rows' times
|
for _ in 0..target_rows {
|
||||||
for value in &original_row_data { // Iterate over elements of the row
|
// Repeat 'target_rows' times
|
||||||
|
for value in &original_row_data {
|
||||||
|
// Iterate over elements of the row
|
||||||
data.push(value.clone());
|
data.push(value.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1250,6 +1266,13 @@ mod tests {
|
|||||||
ma.row_copy_from_slice(1, &new_row);
|
ma.row_copy_from_slice(1, &new_row);
|
||||||
assert_eq!(ma.row(1), &[10, 20, 30]);
|
assert_eq!(ma.row(1), &[10, 20, 30]);
|
||||||
}
|
}
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "row index 4 out of bounds for 3 rows")]
|
||||||
|
fn test_row_copy_from_slice_out_of_bounds() {
|
||||||
|
let mut ma = static_test_matrix();
|
||||||
|
let new_row = vec![10, 20, 30];
|
||||||
|
ma.row_copy_from_slice(4, &new_row);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "row index 3 out of bounds for 3 rows")]
|
#[should_panic(expected = "row index 3 out of bounds for 3 rows")]
|
||||||
@ -2042,14 +2065,18 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix.")]
|
#[should_panic(
|
||||||
|
expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix."
|
||||||
|
)]
|
||||||
fn test_broadcast_row_to_target_shape_panic_not_1_row() {
|
fn test_broadcast_row_to_target_shape_panic_not_1_row() {
|
||||||
let multi_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
let multi_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||||
multi_row_matrix.broadcast_row_to_target_shape(3, 2);
|
multi_row_matrix.broadcast_row_to_target_shape(3, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns.")]
|
#[should_panic(
|
||||||
|
expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns."
|
||||||
|
)]
|
||||||
fn test_broadcast_row_to_target_shape_panic_col_mismatch() {
|
fn test_broadcast_row_to_target_shape_panic_col_mismatch() {
|
||||||
let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
|
let single_row_matrix = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
|
||||||
single_row_matrix.broadcast_row_to_target_shape(5, 4);
|
single_row_matrix.broadcast_row_to_target_shape(5, 4);
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
|
pub mod boolops;
|
||||||
pub mod mat;
|
pub mod mat;
|
||||||
pub mod seriesops;
|
pub mod seriesops;
|
||||||
pub mod boolops;
|
|
||||||
|
|
||||||
|
pub use boolops::*;
|
||||||
pub use mat::*;
|
pub use mat::*;
|
||||||
pub use seriesops::*;
|
pub use seriesops::*;
|
||||||
pub use boolops::*;
|
|
Loading…
x
Reference in New Issue
Block a user