mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Refactor DenseNN implementation to enhance activation function handling and improve training process
This commit is contained in:
parent
005c10e816
commit
261d0d7007
@ -1,65 +1,278 @@
|
|||||||
use crate::matrix::{Matrix, SeriesOps};
|
use crate::matrix::{Matrix, SeriesOps};
|
||||||
use crate::compute::activations::{relu, sigmoid, drelu};
|
use crate::compute::activations::{relu, drelu, sigmoid};
|
||||||
use rand::Rng;
|
use rand::prelude::*;
|
||||||
|
|
||||||
|
/// Supported activation functions
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum ActivationKind {
|
||||||
|
Relu,
|
||||||
|
Sigmoid,
|
||||||
|
Tanh,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ActivationKind {
|
||||||
|
/// Apply activation elementwise
|
||||||
|
pub fn forward(&self, z: &Matrix<f64>) -> Matrix<f64> {
|
||||||
|
match self {
|
||||||
|
ActivationKind::Relu => relu(z),
|
||||||
|
ActivationKind::Sigmoid => sigmoid(z),
|
||||||
|
ActivationKind::Tanh => z.map(|v| v.tanh()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute elementwise derivative w.r.t. pre-activation z
|
||||||
|
pub fn derivative(&self, z: &Matrix<f64>) -> Matrix<f64> {
|
||||||
|
match self {
|
||||||
|
ActivationKind::Relu => drelu(z),
|
||||||
|
ActivationKind::Sigmoid => {
|
||||||
|
let s = sigmoid(z);
|
||||||
|
s.zip(&s, |si, sj| si * (1.0 - sj))
|
||||||
|
}
|
||||||
|
ActivationKind::Tanh => z.map(|v| 1.0 - v.tanh().powi(2)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Weight initialization schemes
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum InitializerKind {
|
||||||
|
/// Uniform(-limit .. limit)
|
||||||
|
Uniform(f64),
|
||||||
|
/// Xavier/Glorot uniform
|
||||||
|
Xavier,
|
||||||
|
/// He (Kaiming) uniform
|
||||||
|
He,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InitializerKind {
|
||||||
|
pub fn initialize(&self, rows: usize, cols: usize) -> Matrix<f64> {
|
||||||
|
let mut rng = rand::rng();
|
||||||
|
let fan_in = rows;
|
||||||
|
let fan_out = cols;
|
||||||
|
let limit = match self {
|
||||||
|
InitializerKind::Uniform(l) => *l,
|
||||||
|
InitializerKind::Xavier => (6.0 / (fan_in + fan_out) as f64).sqrt(),
|
||||||
|
InitializerKind::He => (2.0 / fan_in as f64).sqrt(),
|
||||||
|
};
|
||||||
|
let data = (0..rows * cols)
|
||||||
|
.map(|_| rng.random_range(-limit..limit))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
Matrix::from_vec(data, rows, cols)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Supported losses
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum LossKind {
|
||||||
|
/// Mean Squared Error: L = 1/m * sum((y_hat - y)^2)
|
||||||
|
MSE,
|
||||||
|
/// Binary Cross-Entropy: L = -1/m * sum(y*log(y_hat) + (1-y)*log(1-y_hat))
|
||||||
|
BCE,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LossKind {
|
||||||
|
/// Compute gradient dL/dy_hat (before applying activation derivative)
|
||||||
|
pub fn gradient(&self, y_hat: &Matrix<f64>, y: &Matrix<f64>) -> Matrix<f64> {
|
||||||
|
let m = y.rows() as f64;
|
||||||
|
match self {
|
||||||
|
LossKind::MSE => (y_hat - y) * (2.0 / m),
|
||||||
|
LossKind::BCE => (y_hat - y) * (1.0 / m),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Configuration for a dense neural network
|
||||||
|
pub struct DenseNNConfig {
|
||||||
|
pub input_size: usize,
|
||||||
|
pub hidden_layers: Vec<usize>,
|
||||||
|
/// Must have length = hidden_layers.len() + 1
|
||||||
|
pub activations: Vec<ActivationKind>,
|
||||||
|
pub output_size: usize,
|
||||||
|
pub initializer: InitializerKind,
|
||||||
|
pub loss: LossKind,
|
||||||
|
pub learning_rate: f64,
|
||||||
|
pub epochs: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A multi-layer perceptron with full configurability
|
||||||
pub struct DenseNN {
|
pub struct DenseNN {
|
||||||
w1: Matrix<f64>, // (n_in, n_hidden)
|
weights: Vec<Matrix<f64>>,
|
||||||
b1: Matrix<f64>, // (1, n_hidden)
|
biases: Vec<Matrix<f64>>,
|
||||||
w2: Matrix<f64>, // (n_hidden, n_out)
|
activations: Vec<ActivationKind>,
|
||||||
b2: Matrix<f64>, // (1, n_out)
|
loss: LossKind,
|
||||||
|
lr: f64,
|
||||||
|
epochs: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DenseNN {
|
impl DenseNN {
|
||||||
pub fn new(n_in: usize, n_hidden: usize, n_out: usize) -> Self {
|
/// Build a new DenseNN from the given configuration
|
||||||
let mut rng = rand::rng();
|
pub fn new(config: DenseNNConfig) -> Self {
|
||||||
let mut init = |rows, cols| {
|
let mut sizes = vec![config.input_size];
|
||||||
let data = (0..rows * cols)
|
sizes.extend(&config.hidden_layers);
|
||||||
.map(|_| rng.random_range(-1.0..1.0))
|
sizes.push(config.output_size);
|
||||||
.collect::<Vec<_>>();
|
|
||||||
Matrix::from_vec(data, rows, cols)
|
assert_eq!(
|
||||||
};
|
config.activations.len(),
|
||||||
Self {
|
sizes.len() - 1,
|
||||||
w1: init(n_in, n_hidden),
|
"Number of activation functions must match number of layers"
|
||||||
b1: Matrix::zeros(1, n_hidden),
|
);
|
||||||
w2: init(n_hidden, n_out),
|
|
||||||
b2: Matrix::zeros(1, n_out),
|
let mut weights = Vec::with_capacity(sizes.len() - 1);
|
||||||
|
let mut biases = Vec::with_capacity(sizes.len() - 1);
|
||||||
|
|
||||||
|
for i in 0..sizes.len() - 1 {
|
||||||
|
let w = config.initializer.initialize(sizes[i], sizes[i + 1]);
|
||||||
|
let b = Matrix::zeros(1, sizes[i + 1]);
|
||||||
|
weights.push(w);
|
||||||
|
biases.push(b);
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseNN {
|
||||||
|
weights,
|
||||||
|
biases,
|
||||||
|
activations: config.activations,
|
||||||
|
loss: config.loss,
|
||||||
|
lr: config.learning_rate,
|
||||||
|
epochs: config.epochs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, x: &Matrix<f64>) -> (Matrix<f64>, Matrix<f64>, Matrix<f64>) {
|
/// Perform a full forward pass, returning pre-activations (z) and activations (a)
|
||||||
// z1 = X·W1 + b1 ; a1 = ReLU(z1)
|
fn forward_full(&self, x: &Matrix<f64>) -> (Vec<Matrix<f64>>, Vec<Matrix<f64>>) {
|
||||||
let z1 = x.dot(&self.w1) + &self.b1;
|
let mut zs = Vec::with_capacity(self.weights.len());
|
||||||
let a1 = relu(&z1);
|
let mut activs = Vec::with_capacity(self.weights.len() + 1);
|
||||||
// z2 = a1·W2 + b2 ; a2 = softmax(z2) (here binary => sigmoid)
|
activs.push(x.clone());
|
||||||
let z2 = a1.dot(&self.w2) + &self.b2;
|
|
||||||
let a2 = sigmoid(&z2); // binary output
|
let mut a = x.clone();
|
||||||
(a1, z2, a2) // keep intermediates for back-prop
|
for (i, (w, b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
|
||||||
|
let z = &a.dot(w) + &Matrix::repeat_rows(b, a.rows());
|
||||||
|
let a_next = self.activations[i].forward(&z);
|
||||||
|
zs.push(z);
|
||||||
|
activs.push(a_next.clone());
|
||||||
|
a = a_next;
|
||||||
|
}
|
||||||
|
|
||||||
|
(zs, activs)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn train(&mut self, x: &Matrix<f64>, y: &Matrix<f64>, lr: f64, epochs: usize) {
|
/// Train the network on inputs X and targets Y
|
||||||
|
pub fn train(&mut self, x: &Matrix<f64>, y: &Matrix<f64>) {
|
||||||
let m = x.rows() as f64;
|
let m = x.rows() as f64;
|
||||||
for _ in 0..epochs {
|
for _ in 0..self.epochs {
|
||||||
let (a1, _z2, y_hat) = self.forward(x);
|
let (zs, activs) = self.forward_full(x);
|
||||||
|
let y_hat = activs.last().unwrap().clone();
|
||||||
|
|
||||||
// -------- backwards ----------
|
// Initial delta (dL/dz) on output
|
||||||
// dL/da2 = y_hat - y (BCE derivative)
|
let mut delta = match self.loss {
|
||||||
let dz2 = &y_hat - y; // (m, n_out)
|
LossKind::BCE => self.loss.gradient(&y_hat, y),
|
||||||
let dw2 = a1.transpose().dot(&dz2) / m; // (n_h, n_out)
|
LossKind::MSE => {
|
||||||
// let db2 = dz2.sum_vertical() * (1.0 / m); // broadcast ok
|
let grad = self.loss.gradient(&y_hat, y);
|
||||||
let db2 = Matrix::from_vec(dz2.sum_vertical(), 1, dz2.cols()) * (1.0 / m); // (1, n_out)
|
let dz = self.activations.last().unwrap().derivative(zs.last().unwrap());
|
||||||
let da1 = dz2.dot(&self.w2.transpose()); // (m,n_h)
|
grad.zip(&dz, |g, da| g * da)
|
||||||
let dz1 = da1.zip(&a1, |g, act| g * drelu(&Matrix::from_cols(vec![vec![act]])).data()[0]); // (m,n_h)
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// real code: drelu returns Matrix, broadcasting needed; you can optimise.
|
// Backpropagate through layers
|
||||||
|
for l in (0..self.weights.len()).rev() {
|
||||||
|
let a_prev = &activs[l];
|
||||||
|
let dw = a_prev.transpose().dot(&delta) / m;
|
||||||
|
let db = Matrix::from_vec(delta.sum_vertical(), 1, delta.cols()) / m;
|
||||||
|
|
||||||
let dw1 = x.transpose().dot(&dz1) / m; // (n_in,n_h)
|
// Update weights & biases
|
||||||
let db1 = Matrix::from_vec(dz1.sum_vertical(), 1, dz1.cols()) * (1.0 / m); // (1, n_h)
|
self.weights[l] = &self.weights[l] - &(dw * self.lr);
|
||||||
|
self.biases[l] = &self.biases[l] - &(db * self.lr);
|
||||||
|
|
||||||
// -------- update ----------
|
// Propagate delta to previous layer
|
||||||
self.w2 = &self.w2 - &(dw2 * lr);
|
if l > 0 {
|
||||||
self.b2 = &self.b2 - &(db2 * lr);
|
let w_t = self.weights[l].transpose();
|
||||||
self.w1 = &self.w1 - &(dw1 * lr);
|
let da = self.activations[l - 1].derivative(&zs[l - 1]);
|
||||||
self.b1 = &self.b1 - &(db1 * lr);
|
delta = delta.dot(&w_t).zip(&da, |d, a| d * a);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Run a forward pass and return the network's output
|
||||||
|
pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
|
||||||
|
let mut a = x.clone();
|
||||||
|
for (i, (w, b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
|
||||||
|
let z = &a.dot(w) + &Matrix::repeat_rows(b, a.rows());
|
||||||
|
a = self.activations[i].forward(&z);
|
||||||
|
}
|
||||||
|
a
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Simple tests
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::matrix::Matrix;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_predict_shape() {
|
||||||
|
let config = DenseNNConfig {
|
||||||
|
input_size: 1,
|
||||||
|
hidden_layers: vec![2],
|
||||||
|
activations: vec![ActivationKind::Relu, ActivationKind::Sigmoid],
|
||||||
|
output_size: 1,
|
||||||
|
initializer: InitializerKind::Uniform(0.1),
|
||||||
|
loss: LossKind::MSE,
|
||||||
|
learning_rate: 0.01,
|
||||||
|
epochs: 0,
|
||||||
|
};
|
||||||
|
let model = DenseNN::new(config);
|
||||||
|
let x = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1);
|
||||||
|
let preds = model.predict(&x);
|
||||||
|
assert_eq!(preds.rows(), 3);
|
||||||
|
assert_eq!(preds.cols(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_train_no_epochs() {
|
||||||
|
let config = DenseNNConfig {
|
||||||
|
input_size: 1,
|
||||||
|
hidden_layers: vec![2],
|
||||||
|
activations: vec![ActivationKind::Relu, ActivationKind::Sigmoid],
|
||||||
|
output_size: 1,
|
||||||
|
initializer: InitializerKind::Uniform(0.1),
|
||||||
|
loss: LossKind::MSE,
|
||||||
|
learning_rate: 0.01,
|
||||||
|
epochs: 0,
|
||||||
|
};
|
||||||
|
let mut model = DenseNN::new(config);
|
||||||
|
let x = Matrix::from_vec(vec![1.0, 2.0], 2, 1);
|
||||||
|
let before = model.predict(&x);
|
||||||
|
model.train(&x, &before);
|
||||||
|
let after = model.predict(&x);
|
||||||
|
for i in 0..before.rows() {
|
||||||
|
assert!((before[(i, 0)] - after[(i, 0)]).abs() < 1e-12);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_dense_nn_step() {
|
||||||
|
let config = DenseNNConfig {
|
||||||
|
input_size: 1,
|
||||||
|
hidden_layers: vec![2],
|
||||||
|
activations: vec![ActivationKind::Relu, ActivationKind::Sigmoid],
|
||||||
|
output_size: 1,
|
||||||
|
initializer: InitializerKind::He,
|
||||||
|
loss: LossKind::BCE,
|
||||||
|
learning_rate: 0.01,
|
||||||
|
epochs: 5000,
|
||||||
|
};
|
||||||
|
let mut model = DenseNN::new(config);
|
||||||
|
let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
|
||||||
|
let y = Matrix::from_vec(vec![0.0, 0.0, 1.0, 1.0], 4, 1);
|
||||||
|
model.train(&x, &y);
|
||||||
|
let preds = model.predict(&x);
|
||||||
|
assert!((preds[(0, 0)] - 0.0).abs() < 0.5);
|
||||||
|
assert!((preds[(1, 0)] - 0.0).abs() < 0.5);
|
||||||
|
assert!((preds[(2, 0)] - 1.0).abs() < 0.5);
|
||||||
|
assert!((preds[(3, 0)] - 1.0).abs() < 0.5);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user