Refactor check_weights_signs_lengths to enforce f64 types and validate non-zero weights and signs

This commit is contained in:
Palash Tyagi 2025-04-10 00:24:00 +01:00
parent 1a7d3c1491
commit de62daaf8b

View File

@ -319,9 +319,9 @@ fn form_weights_and_signs_map(
Ok(weights_map)
}
fn check_weights_signs_lengths<T>(
weights_vec: Vec<T>,
signs_vec: Vec<T>,
fn check_weights_signs_lengths(
weights_vec: Vec<f64>,
signs_vec: Vec<f64>,
_agg_xcats_for_cid: bool,
agg_targ_len: usize,
) -> Result<(), Box<dyn std::error::Error>> {
@ -330,11 +330,22 @@ fn check_weights_signs_lengths<T>(
true => "xcats",
false => "cids",
};
for (vx, vname) in vec![(weights_vec.len(), "weights"), (signs_vec.len(), "signs")] {
if vx != agg_targ_len {
for (vx, vname) in vec![
(weights_vec.clone(), "weights"),
(signs_vec.clone(), "signs"),
] {
for (i, v) in vx.iter().enumerate() {
if *v == 0.0 {
return Err(format!("The {} at index {} is 0.0", vname, i).into());
}
}
if vx.len() != agg_targ_len {
return Err(format!(
"The length of {} ({}) does not match the length of {} ({})",
vname, vx, agg_targ, agg_targ_len
vname,
vx.len(),
agg_targ,
agg_targ_len
)
.into());
}