mirror of
https://github.com/Magnus167/msyrs.git
synced 2025-08-20 07:20:01 +00:00
Refactor check_weights_signs_lengths to enforce f64 types and validate non-zero weights and signs
This commit is contained in:
parent
1a7d3c1491
commit
de62daaf8b
@ -319,9 +319,9 @@ fn form_weights_and_signs_map(
|
||||
Ok(weights_map)
|
||||
}
|
||||
|
||||
fn check_weights_signs_lengths<T>(
|
||||
weights_vec: Vec<T>,
|
||||
signs_vec: Vec<T>,
|
||||
fn check_weights_signs_lengths(
|
||||
weights_vec: Vec<f64>,
|
||||
signs_vec: Vec<f64>,
|
||||
_agg_xcats_for_cid: bool,
|
||||
agg_targ_len: usize,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
@ -330,11 +330,22 @@ fn check_weights_signs_lengths<T>(
|
||||
true => "xcats",
|
||||
false => "cids",
|
||||
};
|
||||
for (vx, vname) in vec![(weights_vec.len(), "weights"), (signs_vec.len(), "signs")] {
|
||||
if vx != agg_targ_len {
|
||||
for (vx, vname) in vec![
|
||||
(weights_vec.clone(), "weights"),
|
||||
(signs_vec.clone(), "signs"),
|
||||
] {
|
||||
for (i, v) in vx.iter().enumerate() {
|
||||
if *v == 0.0 {
|
||||
return Err(format!("The {} at index {} is 0.0", vname, i).into());
|
||||
}
|
||||
}
|
||||
if vx.len() != agg_targ_len {
|
||||
return Err(format!(
|
||||
"The length of {} ({}) does not match the length of {} ({})",
|
||||
vname, vx, agg_targ, agg_targ_len
|
||||
vname,
|
||||
vx.len(),
|
||||
agg_targ,
|
||||
agg_targ_len
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user