From e780f3418834793073e257fac2d9457867ad2473 Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Wed, 9 Apr 2025 00:21:08 +0100 Subject: [PATCH] wip: normalization not wroking --- src/panel/linear_composite.rs | 111 ++++++++++++++-------------------- 1 file changed, 44 insertions(+), 67 deletions(-) diff --git a/src/panel/linear_composite.rs b/src/panel/linear_composite.rs index 19f3f3f..f135172 100644 --- a/src/panel/linear_composite.rs +++ b/src/panel/linear_composite.rs @@ -10,72 +10,55 @@ use polars::series::Series; // Series struct use std::collections::HashMap; fn normalize_weights_with_nan_mask( - mut weights_dfw: DataFrame, + weights_dfw: DataFrame, nan_mask_dfw: DataFrame, ) -> Result { - let column_names: Vec = weights_dfw - .get_column_names() - .iter() - .map(|s| s.to_string()) - .collect(); - for col in column_names { - let series = weights_dfw.column(&col)?.clone(); - let mask = nan_mask_dfw.column(&col)?.bool()?; - let masked = series.zip_with( - mask, - &Series::full_null(col.to_string().into(), series.len(), series.dtype()).into_column(), + // Ensure dimensions and columns match + assert_eq!(weights_dfw.shape(), nan_mask_dfw.shape()); + assert!( + weights_dfw + .get_column_names() + .iter() + .all(|col| nan_mask_dfw.get_column_names().contains(col)), + "The columns in weights_dfw and nan_mask_dfw do not match." + ); + + let col_names = weights_dfw.get_column_names(); + + // clone weights_dfw + let mut norm_dfw = DataFrame::new(vec![])?; + // Iterate over each column in the DataFrame, and apply the nan mask + for col in col_names.clone() { + let weight_series = weights_dfw.column(col)?; + let nan_mask_series = nan_mask_dfw.column(col)?; + + // Replace values in `weight_series` with nulls where `nan_mask_series` is true + let updated_series = weight_series.zip_with( + nan_mask_series.bool()?, + &Series::full_null( + col.to_string().into(), + weight_series.len(), + weight_series.dtype(), + ) + .into_column(), )?; - let masked = masked.as_series().unwrap(); - weights_dfw.replace(&col, masked.clone())?; + + // Update the column in the new DataFrame + norm_dfw.with_column(updated_series)?; } - // get the length of weights_dfw - assert!( - weights_dfw.height() > 0, - "weights_dfw is empty after masking." - ); - let row_sums: Vec = (0..weights_dfw.height()) - .map(|i| { - weights_dfw - .get_columns() - .iter() - .filter_map(|s| s.f64().unwrap().get(i)) - .map(f64::abs) - .sum() - }) - .collect(); - // get the height of row_sums - assert!(row_sums.len() > 0, "row_sums is empty after summation."); - - let column_names: Vec = weights_dfw - .get_column_names() - .iter() - .map(|s| s.to_string()) - .collect(); - for col in column_names { - let normalized: Vec> = weights_dfw - .column(&col)? - .f64()? - .into_iter() - .zip(&row_sums) - .map(|(opt_val, sum)| { - // If non-null and row sum is nonzero, normalize; else leave as null. - opt_val.and_then(|v| if *sum != 0.0 { Some(v / sum) } else { None }) - }) - .collect(); - weights_dfw.replace(&col, Series::new(col.to_string().into(), normalized))?; + let sums_of_rows = norm_dfw + .sum_horizontal(polars::frame::NullStrategy::Ignore)? + .unwrap() + .into_column(); + for col in col_names { + let weight_series = norm_dfw.column(col)?; + let divided_series = weight_series.divide(&sums_of_rows)?; + let normalized_series = divided_series.as_series().unwrap().clone(); + norm_dfw.replace(col, normalized_series)?; } - // Check if the DataFrame is empty after normalization - assert!( - weights_dfw.height() > 0, - "weights_dfw is empty after normalization." - ); - assert!( - weights_dfw.get_column_names().len() > 0, - "weights_dfw has no columns after normalization." - ); - Ok(weights_dfw) + Ok(norm_dfw) } fn _form_agg_data_dfw(dfw: &DataFrame, agg_targs: &Vec) -> Result { @@ -270,17 +253,10 @@ fn get_agg_on_agg_targs(cids: Vec, xcats: Vec) -> (Vec, fn get_mul_targets( cids: Vec, xcats: Vec, - dfw: &DataFrame, ) -> Result>, Box> { let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone()); let mut mul_targets = HashMap::new(); - let found_tickers = dfw - .get_column_names() - .iter() - .map(|name| name.to_string()) - .collect::>(); - let (agg_on, agg_targ) = get_agg_on_agg_targs(cids.clone(), xcats.clone()); for agg_o in agg_on { @@ -392,6 +368,7 @@ fn rename_result_dfw_cols( /// Flags if the xcats are aggregated for a given cid. /// If true, the xcats are aggregated for each cid, creating a new xcat. /// If false, the cids are aggregated for each xcat, creating a new cid. +#[allow(unused_variables)] fn agg_xcats_for_cid(cids: Vec, xcats: Vec) -> bool { // if there is more than 1 xcat, return xcats.len() > 1 xcats.len() > 1 @@ -449,7 +426,7 @@ pub fn linear_composite( let dfw = pivot_dataframe_by_ticker(rdf.clone(), Some("value".to_string())).unwrap(); - let mul_targets = get_mul_targets(cids.clone(), xcats.clone(), &dfw)?; + let mul_targets = get_mul_targets(cids.clone(), xcats.clone())?; let weights_map = form_weights_and_signs_map(cids.clone(), xcats.clone(), weights, signs)?; for (ticker, targets) in mul_targets.iter() {