Compare commits

..

No commits in common. "a08fb546a9fcaddd214064292fc83c31c4644dcd" and "4f8a27298cda5002e27fa8a091821f51131b13aa" have entirely different histories.

6 changed files with 25 additions and 23 deletions

View File

@ -1,4 +1,3 @@
pub mod activations;
pub mod models; pub mod models;
pub mod stats;

View File

@ -1,4 +1,4 @@
use crate::compute::models::activations::{drelu, relu, sigmoid}; use crate::compute::activations::{drelu, relu, sigmoid};
use crate::matrix::{Matrix, SeriesOps}; use crate::matrix::{Matrix, SeriesOps};
use rand::prelude::*; use rand::prelude::*;

View File

@ -1,4 +1,4 @@
use crate::compute::models::activations::sigmoid; use crate::compute::activations::sigmoid;
use crate::matrix::{Matrix, SeriesOps}; use crate::matrix::{Matrix, SeriesOps};
pub struct LogReg { pub struct LogReg {

View File

@ -4,4 +4,3 @@ pub mod dense_nn;
pub mod k_means; pub mod k_means;
pub mod pca; pub mod pca;
pub mod gaussian_nb; pub mod gaussian_nb;
pub mod activations;

View File

@ -197,24 +197,17 @@ impl<T: Clone> Matrix<T> {
} }
row_data row_data
} }
pub fn row_copy_from_slice(&mut self, r: usize, values: &[T]) {
#[inline]
pub fn row_mut(&mut self, r: usize) -> &mut [T] {
assert!( assert!(
r < self.rows, r < self.rows,
"row index {} out of bounds for {} rows", "row index {} out of bounds for {} rows",
r, r,
self.rows self.rows
); );
assert!( let start = r;
values.len() == self.cols, &mut self.data[start..start + self.cols]
"input slice length {} does not match number of columns {}",
values.len(),
self.cols
);
for (c, value) in values.iter().enumerate() {
let idx = r + c * self.rows; // column-major index
self.data[idx] = value.clone();
}
} }
/// Deletes a row from the matrix. Panics on out-of-bounds. /// Deletes a row from the matrix. Panics on out-of-bounds.
@ -1192,11 +1185,22 @@ mod tests {
} }
#[test] #[test]
fn test_row_copy_from_slice() { fn test_row_mut() {
let mut ma = static_test_matrix(); let mut ma = static_test_matrix();
let new_row = vec![10, 20, 30]; let row1_mut = ma.row_mut(1);
ma.row_copy_from_slice(1, &new_row); row1_mut[0] = 20;
assert_eq!(ma.row(1), &[10, 20, 30]); row1_mut[1] = 50;
row1_mut[2] = 80;
assert_eq!(ma.row(1), &[20, 50, 80]);
assert_eq!(ma.data(), &[1, 2, 3, 20, 50, 80, 7, 8, 9]);
}
#[test]
#[should_panic(expected = "row index 3 out of bounds for 3 rows")]
fn test_row_mut_out_of_bounds() {
let mut ma = static_test_matrix();
ma.row_mut(3);
} }
#[test] #[test]