From de62daaf8b4410d83ccbc195c704010bad29fe43 Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Thu, 10 Apr 2025 00:24:00 +0100 Subject: [PATCH] Refactor check_weights_signs_lengths to enforce f64 types and validate non-zero weights and signs --- src/panel/linear_composite.rs | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/panel/linear_composite.rs b/src/panel/linear_composite.rs index f135172..6198ce1 100644 --- a/src/panel/linear_composite.rs +++ b/src/panel/linear_composite.rs @@ -319,9 +319,9 @@ fn form_weights_and_signs_map( Ok(weights_map) } -fn check_weights_signs_lengths( - weights_vec: Vec, - signs_vec: Vec, +fn check_weights_signs_lengths( + weights_vec: Vec, + signs_vec: Vec, _agg_xcats_for_cid: bool, agg_targ_len: usize, ) -> Result<(), Box> { @@ -330,11 +330,22 @@ fn check_weights_signs_lengths( true => "xcats", false => "cids", }; - for (vx, vname) in vec![(weights_vec.len(), "weights"), (signs_vec.len(), "signs")] { - if vx != agg_targ_len { + for (vx, vname) in vec![ + (weights_vec.clone(), "weights"), + (signs_vec.clone(), "signs"), + ] { + for (i, v) in vx.iter().enumerate() { + if *v == 0.0 { + return Err(format!("The {} at index {} is 0.0", vname, i).into()); + } + } + if vx.len() != agg_targ_len { return Err(format!( "The length of {} ({}) does not match the length of {} ({})", - vname, vx, agg_targ, agg_targ_len + vname, + vx.len(), + agg_targ, + agg_targ_len ) .into()); }