Applied rustfmt formatting to test assertions

This commit is contained in:
Palash Tyagi
2025-07-12 00:56:09 +01:00
parent 9b08eaeb35
commit de18d8e010
9 changed files with 45 additions and 33 deletions

View File

@@ -349,7 +349,8 @@ mod tests {
assert!(
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
"Tanh forward output mismatch at ({}, {})",
i, j
i,
j
);
}
}
@@ -366,7 +367,8 @@ mod tests {
assert!(
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
"ReLU derivative output mismatch at ({}, {})",
i, j
i,
j
);
}
}
@@ -383,7 +385,8 @@ mod tests {
assert!(
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
"Tanh derivative output mismatch at ({}, {})",
i, j
i,
j
);
}
}
@@ -401,7 +404,10 @@ mod tests {
assert_eq!(matrix.cols(), cols);
for val in matrix.data() {
assert!(*val >= -limit && *val <= limit, "Xavier initialized value out of range");
assert!(
*val >= -limit && *val <= limit,
"Xavier initialized value out of range"
);
}
}
@@ -417,7 +423,10 @@ mod tests {
assert_eq!(matrix.cols(), cols);
for val in matrix.data() {
assert!(*val >= -limit && *val <= limit, "He initialized value out of range");
assert!(
*val >= -limit && *val <= limit,
"He initialized value out of range"
);
}
}
@@ -436,7 +445,8 @@ mod tests {
assert!(
(output_gradient[(i, j)] - expected_gradient[(i, j)]).abs() < 1e-9,
"BCE gradient output mismatch at ({}, {})",
i, j
i,
j
);
}
}
@@ -462,16 +472,22 @@ mod tests {
let before_preds = model.predict(&x);
// BCE loss calculation for testing
let before_loss = -1.0 / (y.rows() as f64) * before_preds.zip(&y, |yh, yv| {
yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln()
}).data().iter().sum::<f64>();
let before_loss = -1.0 / (y.rows() as f64)
* before_preds
.zip(&y, |yh, yv| yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln())
.data()
.iter()
.sum::<f64>();
model.train(&x, &y);
let after_preds = model.predict(&x);
let after_loss = -1.0 / (y.rows() as f64) * after_preds.zip(&y, |yh, yv| {
yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln()
}).data().iter().sum::<f64>();
let after_loss = -1.0 / (y.rows() as f64)
* after_preds
.zip(&y, |yh, yv| yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln())
.data()
.iter()
.sum::<f64>();
assert!(
after_loss < before_loss,