mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-08-20 04:00:01 +00:00
Refactor test assertions to improve readability by removing error messages from assert macros
This commit is contained in:
parent
9182ab9fca
commit
7b0d34384a
@ -264,10 +264,8 @@ mod tests {
|
|||||||
|
|
||||||
for i in 0..before.rows() {
|
for i in 0..before.rows() {
|
||||||
for j in 0..before.cols() {
|
for j in 0..before.cols() {
|
||||||
assert!(
|
// "prediction changed despite 0 epochs"
|
||||||
(before[(i, j)] - after[(i, j)]).abs() < 1e-12,
|
assert!((before[(i, j)] - after[(i, j)]).abs() < 1e-12);
|
||||||
"prediction changed despite 0 epochs"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -330,12 +328,8 @@ mod tests {
|
|||||||
let after_preds = model.predict(&x);
|
let after_preds = model.predict(&x);
|
||||||
let after_loss = mse_loss(&after_preds, &y);
|
let after_loss = mse_loss(&after_preds, &y);
|
||||||
|
|
||||||
assert!(
|
// MSE did not decrease (before: {}, after: {})
|
||||||
after_loss < before_loss,
|
assert!(after_loss < before_loss);
|
||||||
"MSE did not decrease (before: {}, after: {})",
|
|
||||||
before_loss,
|
|
||||||
after_loss
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -346,12 +340,8 @@ mod tests {
|
|||||||
|
|
||||||
for i in 0..input.rows() {
|
for i in 0..input.rows() {
|
||||||
for j in 0..input.cols() {
|
for j in 0..input.cols() {
|
||||||
assert!(
|
// Tanh forward output mismatch at ({}, {})
|
||||||
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
|
assert!((output[(i, j)] - expected[(i, j)]).abs() < 1e-9);
|
||||||
"Tanh forward output mismatch at ({}, {})",
|
|
||||||
i,
|
|
||||||
j
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -364,12 +354,8 @@ mod tests {
|
|||||||
|
|
||||||
for i in 0..input.rows() {
|
for i in 0..input.rows() {
|
||||||
for j in 0..input.cols() {
|
for j in 0..input.cols() {
|
||||||
assert!(
|
// "ReLU derivative output mismatch at ({}, {})"
|
||||||
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
|
assert!((output[(i, j)] - expected[(i, j)]).abs() < 1e-9);
|
||||||
"ReLU derivative output mismatch at ({}, {})",
|
|
||||||
i,
|
|
||||||
j
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -382,12 +368,8 @@ mod tests {
|
|||||||
|
|
||||||
for i in 0..input.rows() {
|
for i in 0..input.rows() {
|
||||||
for j in 0..input.cols() {
|
for j in 0..input.cols() {
|
||||||
assert!(
|
// "Tanh derivative output mismatch at ({}, {})"
|
||||||
(output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
|
assert!((output[(i, j)] - expected[(i, j)]).abs() < 1e-9);
|
||||||
"Tanh derivative output mismatch at ({}, {})",
|
|
||||||
i,
|
|
||||||
j
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -404,10 +386,8 @@ mod tests {
|
|||||||
assert_eq!(matrix.cols(), cols);
|
assert_eq!(matrix.cols(), cols);
|
||||||
|
|
||||||
for val in matrix.data() {
|
for val in matrix.data() {
|
||||||
assert!(
|
// Xavier initialized value out of range
|
||||||
*val >= -limit && *val <= limit,
|
assert!(*val >= -limit && *val <= limit);
|
||||||
"Xavier initialized value out of range"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -423,10 +403,8 @@ mod tests {
|
|||||||
assert_eq!(matrix.cols(), cols);
|
assert_eq!(matrix.cols(), cols);
|
||||||
|
|
||||||
for val in matrix.data() {
|
for val in matrix.data() {
|
||||||
assert!(
|
// He initialized value out of range
|
||||||
*val >= -limit && *val <= limit,
|
assert!(*val >= -limit && *val <= limit);
|
||||||
"He initialized value out of range"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -442,12 +420,8 @@ mod tests {
|
|||||||
|
|
||||||
for i in 0..output_gradient.rows() {
|
for i in 0..output_gradient.rows() {
|
||||||
for j in 0..output_gradient.cols() {
|
for j in 0..output_gradient.cols() {
|
||||||
assert!(
|
// BCE gradient output mismatch at ({}, {})
|
||||||
(output_gradient[(i, j)] - expected_gradient[(i, j)]).abs() < 1e-9,
|
assert!((output_gradient[(i, j)] - expected_gradient[(i, j)]).abs() < 1e-9);
|
||||||
"BCE gradient output mismatch at ({}, {})",
|
|
||||||
i,
|
|
||||||
j
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -489,12 +463,8 @@ mod tests {
|
|||||||
.iter()
|
.iter()
|
||||||
.sum::<f64>();
|
.sum::<f64>();
|
||||||
|
|
||||||
assert!(
|
// BCE did not decrease (before: {}, after: {})
|
||||||
after_loss < before_loss,
|
assert!(after_loss < before_loss,);
|
||||||
"BCE did not decrease (before: {}, after: {})",
|
|
||||||
before_loss,
|
|
||||||
after_loss
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -525,21 +495,15 @@ mod tests {
|
|||||||
|
|
||||||
// Verify that weights and biases of both layers have changed,
|
// Verify that weights and biases of both layers have changed,
|
||||||
// implying delta propagation occurred for l > 0
|
// implying delta propagation occurred for l > 0
|
||||||
assert!(
|
|
||||||
model.weights[0] != initial_weights_l0,
|
|
||||||
"Weights of first layer did not change, delta propagation might not have occurred"
|
// Weights of first layer did not change, delta propagation might not have occurred
|
||||||
);
|
assert!(model.weights[0] != initial_weights_l0);
|
||||||
assert!(
|
// Biases of first layer did not change, delta propagation might not have occurred
|
||||||
model.biases[0] != initial_biases_l0,
|
assert!(model.biases[0] != initial_biases_l0);
|
||||||
"Biases of first layer did not change, delta propagation might not have occurred"
|
// Weights of second layer did not change
|
||||||
);
|
assert!(model.weights[1] != initial_weights_l1);
|
||||||
assert!(
|
// Biases of second layer did not change
|
||||||
model.weights[1] != initial_weights_l1,
|
assert!(model.biases[1] != initial_biases_l1);
|
||||||
"Weights of second layer did not change"
|
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
model.biases[1] != initial_biases_l1,
|
|
||||||
"Biases of second layer did not change"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user