diff --git a/src/panel/linear_composite.rs b/src/panel/linear_composite.rs index 6198ce1..728b933 100644 --- a/src/panel/linear_composite.rs +++ b/src/panel/linear_composite.rs @@ -1,13 +1,42 @@ -use crate::utils::misc::*; use crate::utils::qdf::check_quantamental_dataframe; use crate::utils::qdf::pivots::*; use crate::utils::qdf::reduce_df::*; -use chrono::NaiveDate; -use ndarray::Data; -use ndarray::{s, Array, Array1, Zip}; use polars::prelude::*; -use polars::series::Series; // Series struct use std::collections::HashMap; +/* Threshold used by _norm_apply_nan_mask: a per-column masked-weight sum below this value triggers the sanity check, and the same value bounds the ratio tested in that check. */ const TOLERANCE: f64 = 1e-8; + +/* Builds a new DataFrame holding, for every column name of weights_dfw, the product of that weight column and the logically inverted boolean column of the same name from nan_mask_dfw, cast to Float64. Presumably the cast maps true to 1.0 and false to 0.0, so entries whose mask is true end up as 0.0 rather than null (confirm against the Polars cast semantics). Parameters: weights_dfw supplies the weight columns and the column order; nan_mask_dfw must contain a boolean column for each of those names. Returns Ok with the frame of products. Errors: any failing Polars call (missing column, a mask column that is not boolean, a weight or product column that is not f64, failed cast, multiplication or with_column) is propagated through the question-mark operator. Panics: when a product column sums to less than TOLERANCE, the assert fires unless the inverted mask sums to (within a relative TOLERANCE of) the column length; the unwrap calls on sum() also panic if sum() yields None. NOTE(review): the comparison sum < TOLERANCE has no abs(), so negative sums take the same branch; confirm this is intended. NOTE(review): the assert message labels its second value as the ratio, but the argument passed is nan_mask_sum, not ratio. NOTE(review): nan_mask_series holds the inverted mask, not the mask itself. NOTE(review): the return type is shown as a bare Result; its type parameters may have been lost when this patch text was extracted, verify against the repository. NOTE(review): the clone() on the column-name list looks unnecessary. */ fn _norm_apply_nan_mask( + weights_dfw: &DataFrame, + nan_mask_dfw: &DataFrame, +) -> Result { + let mut norm_dfw = DataFrame::new(vec![])?; + for col in weights_dfw.get_column_names().clone() { + let weight_col = weights_dfw.column(col)?; + let nan_mask_col = nan_mask_dfw.column(col)?; + let nan_mask_series = (!(nan_mask_col.bool()?)).into_column(); + let float_col = nan_mask_series.cast(&DataType::Float64)?; + let float_series = (weight_col * &float_col)?; + let sum = float_series.f64()?.sum().unwrap(); + if sum < TOLERANCE { + // get the sum of the nan_mask_series + let nan_mask_sum = nan_mask_series + .cast(&DataType::Float64)? + .f64()? + .sum() + .unwrap(); + // if the sum of the nan_mask_series is close to the len of the series, then ok + let weight_sum = weight_col.f64()?.sum().unwrap(); + let ratio = (nan_mask_sum - weight_col.len() as f64).abs() / weight_col.len() as f64; + assert!( + ratio < TOLERANCE, + "The sum of the updated series is zero: {:?} w ratio: {:?} and the weights summing to: {:?}", + sum, nan_mask_sum, weight_sum + ); + } + norm_dfw.with_column(float_series)?; + } + Ok(norm_dfw) +} fn normalize_weights_with_nan_mask( weights_dfw: DataFrame, @@ -23,35 +52,14 @@ fn normalize_weights_with_nan_mask( "The columns in weights_dfw and nan_mask_dfw do not match." 
); - let col_names = weights_dfw.get_column_names(); - - // clone weights_dfw - let mut norm_dfw = DataFrame::new(vec![])?; - // Iterate over each column in the DataFrame, and apply the nan mask - for col in col_names.clone() { - let weight_series = weights_dfw.column(col)?; - let nan_mask_series = nan_mask_dfw.column(col)?; - - // Replace values in `weight_series` with nulls where `nan_mask_series` is true - let updated_series = weight_series.zip_with( - nan_mask_series.bool()?, - &Series::full_null( - col.to_string().into(), - weight_series.len(), - weight_series.dtype(), - ) - .into_column(), - )?; - - // Update the column in the new DataFrame - norm_dfw.with_column(updated_series)?; - } + let mut norm_dfw = _norm_apply_nan_mask(&weights_dfw, &nan_mask_dfw)?; let sums_of_rows = norm_dfw .sum_horizontal(polars::frame::NullStrategy::Ignore)? .unwrap() .into_column(); - for col in col_names { + + for col in weights_dfw.get_column_names() { let weight_series = norm_dfw.column(col)?; let divided_series = weight_series.divide(&sums_of_rows)?; let normalized_series = divided_series.as_series().unwrap().clone();