mirror of
https://github.com/Magnus167/msyrs.git
synced 2025-08-20 13:00:01 +00:00
Refactor check_weights_signs_lengths to enforce f64 types and validate non-zero weights and signs
This commit is contained in:
parent
1a7d3c1491
commit
de62daaf8b
@ -319,9 +319,9 @@ fn form_weights_and_signs_map(
|
|||||||
Ok(weights_map)
|
Ok(weights_map)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_weights_signs_lengths<T>(
|
fn check_weights_signs_lengths(
|
||||||
weights_vec: Vec<T>,
|
weights_vec: Vec<f64>,
|
||||||
signs_vec: Vec<T>,
|
signs_vec: Vec<f64>,
|
||||||
_agg_xcats_for_cid: bool,
|
_agg_xcats_for_cid: bool,
|
||||||
agg_targ_len: usize,
|
agg_targ_len: usize,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
@ -330,11 +330,22 @@ fn check_weights_signs_lengths<T>(
|
|||||||
true => "xcats",
|
true => "xcats",
|
||||||
false => "cids",
|
false => "cids",
|
||||||
};
|
};
|
||||||
for (vx, vname) in vec![(weights_vec.len(), "weights"), (signs_vec.len(), "signs")] {
|
for (vx, vname) in vec![
|
||||||
if vx != agg_targ_len {
|
(weights_vec.clone(), "weights"),
|
||||||
|
(signs_vec.clone(), "signs"),
|
||||||
|
] {
|
||||||
|
for (i, v) in vx.iter().enumerate() {
|
||||||
|
if *v == 0.0 {
|
||||||
|
return Err(format!("The {} at index {} is 0.0", vname, i).into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if vx.len() != agg_targ_len {
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"The length of {} ({}) does not match the length of {} ({})",
|
"The length of {} ({}) does not match the length of {} ({})",
|
||||||
vname, vx, agg_targ, agg_targ_len
|
vname,
|
||||||
|
vx.len(),
|
||||||
|
agg_targ,
|
||||||
|
agg_targ_len
|
||||||
)
|
)
|
||||||
.into());
|
.into());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user