mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-21 03:49:59 +00:00
Compare commits
No commits in common. "a08fb546a9fcaddd214064292fc83c31c4644dcd" and "4f8a27298cda5002e27fa8a091821f51131b13aa" have entirely different histories.
a08fb546a9
...
4f8a27298c
@ -1,4 +1,3 @@
|
|||||||
|
pub mod activations;
|
||||||
|
|
||||||
pub mod models;
|
pub mod models;
|
||||||
|
|
||||||
pub mod stats;
|
|
@ -1,4 +1,4 @@
|
|||||||
use crate::compute::models::activations::{drelu, relu, sigmoid};
|
use crate::compute::activations::{drelu, relu, sigmoid};
|
||||||
use crate::matrix::{Matrix, SeriesOps};
|
use crate::matrix::{Matrix, SeriesOps};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::compute::models::activations::sigmoid;
|
use crate::compute::activations::sigmoid;
|
||||||
use crate::matrix::{Matrix, SeriesOps};
|
use crate::matrix::{Matrix, SeriesOps};
|
||||||
|
|
||||||
pub struct LogReg {
|
pub struct LogReg {
|
||||||
|
@ -4,4 +4,3 @@ pub mod dense_nn;
|
|||||||
pub mod k_means;
|
pub mod k_means;
|
||||||
pub mod pca;
|
pub mod pca;
|
||||||
pub mod gaussian_nb;
|
pub mod gaussian_nb;
|
||||||
pub mod activations;
|
|
||||||
|
@ -197,24 +197,17 @@ impl<T: Clone> Matrix<T> {
|
|||||||
}
|
}
|
||||||
row_data
|
row_data
|
||||||
}
|
}
|
||||||
pub fn row_copy_from_slice(&mut self, r: usize, values: &[T]) {
|
|
||||||
|
#[inline]
|
||||||
|
pub fn row_mut(&mut self, r: usize) -> &mut [T] {
|
||||||
assert!(
|
assert!(
|
||||||
r < self.rows,
|
r < self.rows,
|
||||||
"row index {} out of bounds for {} rows",
|
"row index {} out of bounds for {} rows",
|
||||||
r,
|
r,
|
||||||
self.rows
|
self.rows
|
||||||
);
|
);
|
||||||
assert!(
|
let start = r;
|
||||||
values.len() == self.cols,
|
&mut self.data[start..start + self.cols]
|
||||||
"input slice length {} does not match number of columns {}",
|
|
||||||
values.len(),
|
|
||||||
self.cols
|
|
||||||
);
|
|
||||||
|
|
||||||
for (c, value) in values.iter().enumerate() {
|
|
||||||
let idx = r + c * self.rows; // column-major index
|
|
||||||
self.data[idx] = value.clone();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Deletes a row from the matrix. Panics on out-of-bounds.
|
/// Deletes a row from the matrix. Panics on out-of-bounds.
|
||||||
@ -1192,11 +1185,22 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_row_copy_from_slice() {
|
fn test_row_mut() {
|
||||||
let mut ma = static_test_matrix();
|
let mut ma = static_test_matrix();
|
||||||
let new_row = vec![10, 20, 30];
|
let row1_mut = ma.row_mut(1);
|
||||||
ma.row_copy_from_slice(1, &new_row);
|
row1_mut[0] = 20;
|
||||||
assert_eq!(ma.row(1), &[10, 20, 30]);
|
row1_mut[1] = 50;
|
||||||
|
row1_mut[2] = 80;
|
||||||
|
|
||||||
|
assert_eq!(ma.row(1), &[20, 50, 80]);
|
||||||
|
assert_eq!(ma.data(), &[1, 2, 3, 20, 50, 80, 7, 8, 9]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "row index 3 out of bounds for 3 rows")]
|
||||||
|
fn test_row_mut_out_of_bounds() {
|
||||||
|
let mut ma = static_test_matrix();
|
||||||
|
ma.row_mut(3);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user