mirror of https://github.com/Magnus167/rustframe.git
synced 2025-08-20 11:00:00 +00:00

Compare commits: c0f82d5ce8 ... bfdd3ca9ec (9 commits)

Commits in this range:
- bfdd3ca9ec
- 261d0d7007
- 005c10e816
- 4c626bf09c
- ab6d5f9f8f
- 1c8fcc0bad
- 2ca496cfd1
- 85154a3be0
- 54a266b630
Changes to the activation functions module (adds leaky ReLU and its derivative, plus edge-case tests):

```diff
@ -17,14 +17,25 @@ pub fn drelu(x: &Matrix<f64>) -> Matrix<f64> {
     x.map(|v| if v > 0.0 { 1.0 } else { 0.0 })
 }
 
+pub fn leaky_relu(x: &Matrix<f64>) -> Matrix<f64> {
+    x.map(|v| if v > 0.0 { v } else { 0.01 * v })
+}
+
+pub fn dleaky_relu(x: &Matrix<f64>) -> Matrix<f64> {
+    x.map(|v| if v > 0.0 { 1.0 } else { 0.01 })
+}
+
 mod tests {
     use super::*;
 
     // Helper function to round all elements in a matrix to n decimal places
     fn _round_matrix(mat: &Matrix<f64>, decimals: u32) -> Matrix<f64> {
         let factor = 10f64.powi(decimals as i32);
-        let rounded: Vec<f64> = mat.to_vec().iter().map(|v| (v * factor).round() / factor).collect();
+        let rounded: Vec<f64> = mat
+            .to_vec()
+            .iter()
+            .map(|v| (v * factor).round() / factor)
+            .collect();
         Matrix::from_vec(rounded, mat.rows(), mat.cols())
     }
@ -36,6 +47,17 @@ mod tests {
         assert_eq!(_round_matrix(&result, 6), _round_matrix(&expected, 6));
     }
 
+    #[test]
+    fn test_sigmoid_edge_case() {
+        let x = Matrix::from_vec(vec![-1000.0, 0.0, 1000.0], 3, 1);
+        let expected = Matrix::from_vec(vec![0.0, 0.5, 1.0], 3, 1);
+        let result = sigmoid(&x);
+
+        for (r, e) in result.data().iter().zip(expected.data().iter()) {
+            assert!((r - e).abs() < 1e-6);
+        }
+    }
+
     #[test]
     fn test_relu() {
         let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
@ -43,6 +65,13 @@ mod tests {
         assert_eq!(relu(&x), expected);
     }
 
+    #[test]
+    fn test_relu_edge_case() {
+        let x = Matrix::from_vec(vec![-1e-10, 0.0, 1e10], 3, 1);
+        let expected = Matrix::from_vec(vec![0.0, 0.0, 1e10], 3, 1);
+        assert_eq!(relu(&x), expected);
+    }
+
     #[test]
     fn test_dsigmoid() {
         let y = Matrix::from_vec(vec![0.26894142, 0.5, 0.73105858], 3, 1);
@ -50,11 +79,57 @@ mod tests {
         let result = dsigmoid(&y);
         assert_eq!(_round_matrix(&result, 6), _round_matrix(&expected, 6));
     }
 
+    #[test]
+    fn test_dsigmoid_edge_case() {
+        let y = Matrix::from_vec(vec![0.0, 0.5, 1.0], 3, 1); // Assume these are outputs from sigmoid(x)
+        let expected = Matrix::from_vec(vec![0.0, 0.25, 0.0], 3, 1);
+        let result = dsigmoid(&y);
+
+        for (r, e) in result.data().iter().zip(expected.data().iter()) {
+            assert!((r - e).abs() < 1e-6);
+        }
+    }
+
     #[test]
     fn test_drelu() {
         let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
         let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
         assert_eq!(drelu(&x), expected);
     }
-}
+
+    #[test]
+    fn test_drelu_edge_case() {
+        let x = Matrix::from_vec(vec![-1e-10, 0.0, 1e10], 3, 1);
+        let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
+        assert_eq!(drelu(&x), expected);
+    }
+
+    #[test]
+    fn test_leaky_relu() {
+        let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
+        let expected = Matrix::from_vec(vec![-0.01, 0.0, 1.0], 3, 1);
+        assert_eq!(leaky_relu(&x), expected);
+    }
+
+    #[test]
+    fn test_leaky_relu_edge_case() {
+        let x = Matrix::from_vec(vec![-1e-10, 0.0, 1e10], 3, 1);
+        let expected = Matrix::from_vec(vec![-1e-12, 0.0, 1e10], 3, 1);
+        assert_eq!(leaky_relu(&x), expected);
+    }
+
+    #[test]
+    fn test_dleaky_relu() {
+        let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
+        let expected = Matrix::from_vec(vec![0.01, 0.01, 1.0], 3, 1);
+        assert_eq!(dleaky_relu(&x), expected);
+    }
+
+    #[test]
+    fn test_dleaky_relu_edge_case() {
+        let x = Matrix::from_vec(vec![-1e-10, 0.0, 1e10], 3, 1);
+        let expected = Matrix::from_vec(vec![0.01, 0.01, 1.0], 3, 1);
+        assert_eq!(dleaky_relu(&x), expected);
+    }
+}
```
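For orientation (not part of the diff): the leaky variants differ from plain ReLU only in how non-positive inputs are handled, scaling them by a fixed 0.01 instead of clamping to zero, so a small gradient always survives. A minimal sketch using the functions added above; the input values are illustrative only:

```rust
// Hypothetical example, assuming the same Matrix API used by the tests above.
use crate::compute::activations::{dleaky_relu, leaky_relu, relu};
use crate::matrix::Matrix;

#[test]
fn leaky_relu_keeps_a_small_negative_slope() {
    let x = Matrix::from_vec(vec![-2.0, 3.0], 2, 1);
    assert_eq!(relu(&x), Matrix::from_vec(vec![0.0, 3.0], 2, 1));         // negatives clamped to 0
    assert_eq!(leaky_relu(&x), Matrix::from_vec(vec![-0.02, 3.0], 2, 1)); // negatives scaled by 0.01
    assert_eq!(dleaky_relu(&x), Matrix::from_vec(vec![0.01, 1.0], 2, 1)); // gradient stays at 0.01
}
```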
Rewrite of the dense neural network model (DenseNN): the fixed two-layer network becomes a configurable multi-layer perceptron with selectable activations, initializers, and losses, plus tests:

```diff
@ -1,65 +1,278 @@
 use crate::matrix::{Matrix, SeriesOps};
-use crate::compute::activations::{relu, sigmoid, drelu};
-use rand::Rng;
+use crate::compute::activations::{relu, drelu, sigmoid};
+use rand::prelude::*;
+
+/// Supported activation functions
+#[derive(Clone)]
+pub enum ActivationKind {
+    Relu,
+    Sigmoid,
+    Tanh,
+}
+
+impl ActivationKind {
+    /// Apply activation elementwise
+    pub fn forward(&self, z: &Matrix<f64>) -> Matrix<f64> {
+        match self {
+            ActivationKind::Relu => relu(z),
+            ActivationKind::Sigmoid => sigmoid(z),
+            ActivationKind::Tanh => z.map(|v| v.tanh()),
+        }
+    }
+
+    /// Compute elementwise derivative w.r.t. pre-activation z
+    pub fn derivative(&self, z: &Matrix<f64>) -> Matrix<f64> {
+        match self {
+            ActivationKind::Relu => drelu(z),
+            ActivationKind::Sigmoid => {
+                let s = sigmoid(z);
+                s.zip(&s, |si, sj| si * (1.0 - sj))
+            }
+            ActivationKind::Tanh => z.map(|v| 1.0 - v.tanh().powi(2)),
+        }
+    }
+}
+
+/// Weight initialization schemes
+#[derive(Clone)]
+pub enum InitializerKind {
+    /// Uniform(-limit .. limit)
+    Uniform(f64),
+    /// Xavier/Glorot uniform
+    Xavier,
+    /// He (Kaiming) uniform
+    He,
+}
+
+impl InitializerKind {
+    pub fn initialize(&self, rows: usize, cols: usize) -> Matrix<f64> {
+        let mut rng = rand::rng();
+        let fan_in = rows;
+        let fan_out = cols;
+        let limit = match self {
+            InitializerKind::Uniform(l) => *l,
+            InitializerKind::Xavier => (6.0 / (fan_in + fan_out) as f64).sqrt(),
+            InitializerKind::He => (2.0 / fan_in as f64).sqrt(),
+        };
+        let data = (0..rows * cols)
+            .map(|_| rng.random_range(-limit..limit))
+            .collect::<Vec<_>>();
+        Matrix::from_vec(data, rows, cols)
+    }
+}
+
+/// Supported losses
+#[derive(Clone)]
+pub enum LossKind {
+    /// Mean Squared Error: L = 1/m * sum((y_hat - y)^2)
+    MSE,
+    /// Binary Cross-Entropy: L = -1/m * sum(y*log(y_hat) + (1-y)*log(1-y_hat))
+    BCE,
+}
+
+impl LossKind {
+    /// Compute gradient dL/dy_hat (before applying activation derivative)
+    pub fn gradient(&self, y_hat: &Matrix<f64>, y: &Matrix<f64>) -> Matrix<f64> {
+        let m = y.rows() as f64;
+        match self {
+            LossKind::MSE => (y_hat - y) * (2.0 / m),
+            LossKind::BCE => (y_hat - y) * (1.0 / m),
+        }
+    }
+}
+
+/// Configuration for a dense neural network
+pub struct DenseNNConfig {
+    pub input_size: usize,
+    pub hidden_layers: Vec<usize>,
+    /// Must have length = hidden_layers.len() + 1
+    pub activations: Vec<ActivationKind>,
+    pub output_size: usize,
+    pub initializer: InitializerKind,
+    pub loss: LossKind,
+    pub learning_rate: f64,
+    pub epochs: usize,
+}
+
+/// A multi-layer perceptron with full configurability
 pub struct DenseNN {
-    w1: Matrix<f64>, // (n_in, n_hidden)
-    b1: Matrix<f64>, // (1, n_hidden)
-    w2: Matrix<f64>, // (n_hidden, n_out)
-    b2: Matrix<f64>, // (1, n_out)
+    weights: Vec<Matrix<f64>>,
+    biases: Vec<Matrix<f64>>,
+    activations: Vec<ActivationKind>,
+    loss: LossKind,
+    lr: f64,
+    epochs: usize,
 }
 
 impl DenseNN {
-    pub fn new(n_in: usize, n_hidden: usize, n_out: usize) -> Self {
-        let mut rng = rand::rng();
-        let mut init = |rows, cols| {
-            let data = (0..rows * cols)
-                .map(|_| rng.random_range(-1.0..1.0))
-                .collect::<Vec<_>>();
-            Matrix::from_vec(data, rows, cols)
-        };
-        Self {
-            w1: init(n_in, n_hidden),
-            b1: Matrix::zeros(1, n_hidden),
-            w2: init(n_hidden, n_out),
-            b2: Matrix::zeros(1, n_out),
+    /// Build a new DenseNN from the given configuration
+    pub fn new(config: DenseNNConfig) -> Self {
+        let mut sizes = vec![config.input_size];
+        sizes.extend(&config.hidden_layers);
+        sizes.push(config.output_size);
+
+        assert_eq!(
+            config.activations.len(),
+            sizes.len() - 1,
+            "Number of activation functions must match number of layers"
+        );
+
+        let mut weights = Vec::with_capacity(sizes.len() - 1);
+        let mut biases = Vec::with_capacity(sizes.len() - 1);
+
+        for i in 0..sizes.len() - 1 {
+            let w = config.initializer.initialize(sizes[i], sizes[i + 1]);
+            let b = Matrix::zeros(1, sizes[i + 1]);
+            weights.push(w);
+            biases.push(b);
+        }
+
+        DenseNN {
+            weights,
+            biases,
+            activations: config.activations,
+            loss: config.loss,
+            lr: config.learning_rate,
+            epochs: config.epochs,
         }
     }
 
-    pub fn forward(&self, x: &Matrix<f64>) -> (Matrix<f64>, Matrix<f64>, Matrix<f64>) {
-        // z1 = X·W1 + b1 ; a1 = ReLU(z1)
-        let z1 = x.dot(&self.w1) + &self.b1;
-        let a1 = relu(&z1);
-        // z2 = a1·W2 + b2 ; a2 = softmax(z2) (here binary => sigmoid)
-        let z2 = a1.dot(&self.w2) + &self.b2;
-        let a2 = sigmoid(&z2); // binary output
-        (a1, z2, a2) // keep intermediates for back-prop
+    /// Perform a full forward pass, returning pre-activations (z) and activations (a)
+    fn forward_full(&self, x: &Matrix<f64>) -> (Vec<Matrix<f64>>, Vec<Matrix<f64>>) {
+        let mut zs = Vec::with_capacity(self.weights.len());
+        let mut activs = Vec::with_capacity(self.weights.len() + 1);
+        activs.push(x.clone());
+
+        let mut a = x.clone();
+        for (i, (w, b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
+            let z = &a.dot(w) + &Matrix::repeat_rows(b, a.rows());
+            let a_next = self.activations[i].forward(&z);
+            zs.push(z);
+            activs.push(a_next.clone());
+            a = a_next;
+        }
+
+        (zs, activs)
     }
 
-    pub fn train(&mut self, x: &Matrix<f64>, y: &Matrix<f64>, lr: f64, epochs: usize) {
+    /// Train the network on inputs X and targets Y
+    pub fn train(&mut self, x: &Matrix<f64>, y: &Matrix<f64>) {
         let m = x.rows() as f64;
-        for _ in 0..epochs {
-            let (a1, _z2, y_hat) = self.forward(x);
+        for _ in 0..self.epochs {
+            let (zs, activs) = self.forward_full(x);
+            let y_hat = activs.last().unwrap().clone();
 
-            // -------- backwards ----------
-            // dL/da2 = y_hat - y (BCE derivative)
-            let dz2 = &y_hat - y; // (m, n_out)
-            let dw2 = a1.transpose().dot(&dz2) / m; // (n_h, n_out)
-            // let db2 = dz2.sum_vertical() * (1.0 / m); // broadcast ok
-            let db2 = Matrix::from_vec(dz2.sum_vertical(), 1, dz2.cols()) * (1.0 / m); // (1, n_out)
-            let da1 = dz2.dot(&self.w2.transpose()); // (m,n_h)
-            let dz1 = da1.zip(&a1, |g, act| g * drelu(&Matrix::from_cols(vec![vec![act]])).data()[0]); // (m,n_h)
-            // real code: drelu returns Matrix, broadcasting needed; you can optimise.
-            let dw1 = x.transpose().dot(&dz1) / m; // (n_in,n_h)
-            let db1 = Matrix::from_vec(dz1.sum_vertical(), 1, dz1.cols()) * (1.0 / m); // (1, n_h)
+            // Initial delta (dL/dz) on output
+            let mut delta = match self.loss {
+                LossKind::BCE => self.loss.gradient(&y_hat, y),
+                LossKind::MSE => {
+                    let grad = self.loss.gradient(&y_hat, y);
+                    let dz = self.activations.last().unwrap().derivative(zs.last().unwrap());
+                    grad.zip(&dz, |g, da| g * da)
+                }
+            };
 
-            // -------- update ----------
-            self.w2 = &self.w2 - &(dw2 * lr);
-            self.b2 = &self.b2 - &(db2 * lr);
-            self.w1 = &self.w1 - &(dw1 * lr);
-            self.b1 = &self.b1 - &(db1 * lr);
+            // Backpropagate through layers
+            for l in (0..self.weights.len()).rev() {
+                let a_prev = &activs[l];
+                let dw = a_prev.transpose().dot(&delta) / m;
+                let db = Matrix::from_vec(delta.sum_vertical(), 1, delta.cols()) / m;
+
+                // Update weights & biases
+                self.weights[l] = &self.weights[l] - &(dw * self.lr);
+                self.biases[l] = &self.biases[l] - &(db * self.lr);
+
+                // Propagate delta to previous layer
+                if l > 0 {
+                    let w_t = self.weights[l].transpose();
+                    let da = self.activations[l - 1].derivative(&zs[l - 1]);
+                    delta = delta.dot(&w_t).zip(&da, |d, a| d * a);
+                }
+            }
         }
     }
+
+    /// Run a forward pass and return the network's output
+    pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
+        let mut a = x.clone();
+        for (i, (w, b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
+            let z = &a.dot(w) + &Matrix::repeat_rows(b, a.rows());
+            a = self.activations[i].forward(&z);
+        }
+        a
+    }
+}
+
+// ------------------------------
+// Simple tests
+// ------------------------------
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::matrix::Matrix;
+
+    #[test]
+    fn test_predict_shape() {
+        let config = DenseNNConfig {
+            input_size: 1,
+            hidden_layers: vec![2],
+            activations: vec![ActivationKind::Relu, ActivationKind::Sigmoid],
+            output_size: 1,
+            initializer: InitializerKind::Uniform(0.1),
+            loss: LossKind::MSE,
+            learning_rate: 0.01,
+            epochs: 0,
+        };
+        let model = DenseNN::new(config);
+        let x = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1);
+        let preds = model.predict(&x);
+        assert_eq!(preds.rows(), 3);
+        assert_eq!(preds.cols(), 1);
+    }
+
+    #[test]
+    fn test_train_no_epochs() {
+        let config = DenseNNConfig {
+            input_size: 1,
+            hidden_layers: vec![2],
+            activations: vec![ActivationKind::Relu, ActivationKind::Sigmoid],
+            output_size: 1,
+            initializer: InitializerKind::Uniform(0.1),
+            loss: LossKind::MSE,
+            learning_rate: 0.01,
+            epochs: 0,
+        };
+        let mut model = DenseNN::new(config);
+        let x = Matrix::from_vec(vec![1.0, 2.0], 2, 1);
+        let before = model.predict(&x);
+        model.train(&x, &before);
+        let after = model.predict(&x);
+        for i in 0..before.rows() {
+            assert!((before[(i, 0)] - after[(i, 0)]).abs() < 1e-12);
+        }
+    }
+
+    #[test]
+    fn test_dense_nn_step() {
+        let config = DenseNNConfig {
+            input_size: 1,
+            hidden_layers: vec![2],
+            activations: vec![ActivationKind::Relu, ActivationKind::Sigmoid],
+            output_size: 1,
+            initializer: InitializerKind::He,
+            loss: LossKind::BCE,
+            learning_rate: 0.01,
+            epochs: 5000,
+        };
+        let mut model = DenseNN::new(config);
+        let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
+        let y = Matrix::from_vec(vec![0.0, 0.0, 1.0, 1.0], 4, 1);
+        model.train(&x, &y);
+        let preds = model.predict(&x);
+        assert!((preds[(0, 0)] - 0.0).abs() < 0.5);
+        assert!((preds[(1, 0)] - 0.0).abs() < 0.5);
+        assert!((preds[(2, 0)] - 1.0).abs() < 0.5);
+        assert!((preds[(3, 0)] - 1.0).abs() < 0.5);
+    }
+}
```
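A usage sketch of the reworked API (not part of the diff; sizes and hyperparameters are illustrative, mirroring the tests above). The config lists one activation per weight layer, i.e. `hidden_layers.len() + 1` entries, and `train`/`predict` replace the old fixed two-layer `forward`. Note that the BCE branch feeds the raw loss gradient in as the output delta; this relies on the usual simplification that, with a sigmoid output, the activation derivative cancels against the BCE gradient, leaving `(y_hat - y) / m`.

```rust
// Hypothetical usage sketch; assumes the DenseNN types defined above are in scope.
let config = DenseNNConfig {
    input_size: 1,
    hidden_layers: vec![2],
    // one activation per weight layer: hidden_layers.len() + 1 entries
    activations: vec![ActivationKind::Relu, ActivationKind::Sigmoid],
    output_size: 1,
    initializer: InitializerKind::He,
    loss: LossKind::BCE,
    learning_rate: 0.1,
    epochs: 1_000,
};
let mut model = DenseNN::new(config);

let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
let y = Matrix::from_vec(vec![0.0, 0.0, 1.0, 1.0], 4, 1);
model.train(&x, &y);           // learning rate and epoch count come from the config
let probs = model.predict(&x); // same shape as y, values in (0, 1)
```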
Test-module cleanup for the linear regression model (LinReg):

```diff
@ -34,10 +34,10 @@ impl LinReg {
     }
 }
 
+#[cfg(test)]
 mod tests {
-    use super::LinReg;
-    use crate::matrix::{Matrix};
+    use super::*;
 
     #[test]
     fn test_linreg_fit_predict() {
```
Changes to the logistic regression model (LogReg): import ordering, a reformatted `predict`, and a new test module:

```diff
@ -1,5 +1,5 @@
-use crate::matrix::{Matrix, SeriesOps};
 use crate::compute::activations::sigmoid;
+use crate::matrix::{Matrix, SeriesOps};
 
 pub struct LogReg {
     w: Matrix<f64>,
@ -15,14 +15,14 @@ impl LogReg {
     }
 
     pub fn predict_proba(&self, x: &Matrix<f64>) -> Matrix<f64> {
         sigmoid(&(x.dot(&self.w) + self.b)) // σ(Xw + b)
     }
 
     pub fn fit(&mut self, x: &Matrix<f64>, y: &Matrix<f64>, lr: f64, epochs: usize) {
         let m = x.rows() as f64;
         for _ in 0..epochs {
             let p = self.predict_proba(x); // shape (m,1)
             let err = &p - y; // derivative of BCE wrt pre-sigmoid
             let grad_w = x.transpose().dot(&err) / m;
             let grad_b = err.sum_vertical().iter().sum::<f64>() / m;
             self.w = &self.w - &(grad_w * lr);
@ -31,6 +31,25 @@ impl LogReg {
     }
 
     pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
-        self.predict_proba(x).map(|p| if p >= 0.5 { 1.0 } else { 0.0 })
+        self.predict_proba(x)
+            .map(|p| if p >= 0.5 { 1.0 } else { 0.0 })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_logreg_fit_predict() {
+        let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
+        let y = Matrix::from_vec(vec![0.0, 0.0, 1.0, 1.0], 4, 1);
+        let mut model = LogReg::new(1);
+        model.fit(&x, &y, 0.01, 10000);
+        let preds = model.predict(&x);
+        assert_eq!(preds[(0, 0)], 0.0);
+        assert_eq!(preds[(1, 0)], 0.0);
+        assert_eq!(preds[(2, 0)], 1.0);
+        assert_eq!(preds[(3, 0)], 1.0);
     }
 }
```
Additions to the core `Matrix` type: a `shape()` accessor and a `repeat_rows` helper (used by DenseNN for bias broadcasting), with tests:

```diff
@ -89,6 +89,10 @@ impl<T: Clone> Matrix<T> {
         self.cols
     }
 
+    pub fn shape(&self) -> (usize, usize) {
+        (self.rows, self.cols)
+    }
+
     /// Get element reference (immutable). Panics on out-of-bounds.
     pub fn get(&self, r: usize, c: usize) -> &T {
         &self[(r, c)]
@ -323,6 +327,21 @@ impl<T: Clone> Matrix<T> {
         self.data = new_data;
         self.rows = new_rows;
     }
+
+    /// Return a new matrix where row 0 of `self` is repeated `n` times.
+    pub fn repeat_rows(&self, n: usize) -> Matrix<T>
+    where
+        T: Clone,
+    {
+        let mut data = Vec::with_capacity(n * self.cols());
+        let zeroth_row = self.row(0);
+        for value in &zeroth_row {
+            for _ in 0..n {
+                data.push(value.clone()); // Clone each element
+            }
+        }
+        Matrix::from_vec(data, n, self.cols)
+    }
 }
 
 impl Matrix<f64> {
@ -1153,6 +1172,25 @@ mod tests {
         assert_eq!(ma.row(2), &[3, 6, 9]);
     }
 
+    #[test]
+    fn test_shape() {
+        let ma = static_test_matrix_2x4();
+        assert_eq!(ma.shape(), (2, 4));
+        assert_eq!(ma.rows(), 2);
+        assert_eq!(ma.cols(), 4);
+    }
+
+    #[test]
+    fn test_repeat_rows() {
+        let ma = static_test_matrix();
+        // Returns a new matrix where row 0 of `self` is repeated `n` times.
+        let repeated = ma.repeat_rows(3);
+        // assert all rows are equal to the first row
+        for r in 0..repeated.rows() {
+            assert_eq!(repeated.row(r), ma.row(0));
+        }
+    }
+
     #[test]
     #[should_panic(expected = "row index 3 out of bounds for 3 rows")]
     fn test_row_out_of_bounds() {
```