Refactor normalization logic to apply NaN mask in a separate function

This commit is contained in:
Palash Tyagi 2025-04-10 00:24:30 +01:00
parent de62daaf8b
commit 3ca7221965

View File

@ -1,13 +1,42 @@
use crate::utils::misc::*;
use crate::utils::qdf::check_quantamental_dataframe; use crate::utils::qdf::check_quantamental_dataframe;
use crate::utils::qdf::pivots::*; use crate::utils::qdf::pivots::*;
use crate::utils::qdf::reduce_df::*; use crate::utils::qdf::reduce_df::*;
use chrono::NaiveDate;
use ndarray::Data;
use ndarray::{s, Array, Array1, Zip};
use polars::prelude::*; use polars::prelude::*;
use polars::series::Series; // Series struct
use std::collections::HashMap; use std::collections::HashMap;
/// Threshold used by the weight checks in this module: a masked column sum
/// below this value is treated as zero, and a relative mismatch below this
/// value is treated as "no mismatch".
const TOLERANCE: f64 = 1e-8;
/// Applies the NaN mask to every weight column by multiplication.
///
/// For each column of `weights_dfw`, the same-named boolean column of
/// `nan_mask_dfw` is logically inverted, cast to `Float64`, and multiplied
/// element-wise with the weights. The products are collected, in the column
/// order of `weights_dfw`, into a new frame that is returned.
///
/// # Errors
///
/// Propagates any `PolarsError` raised while looking up a column in either
/// frame, reading the mask as boolean, casting, multiplying, reading a column
/// as `f64`, or appending the product to the output frame.
///
/// # Panics
///
/// * If `sum()` on one of the `f64` columns returns `None`.
/// * If a masked column sums to less than `TOLERANCE` while the inverted mask
///   does not sum to (approximately) the column length, i.e. the near-zero
///   total cannot be attributed to the raw weights alone.
///   NOTE(review): for a zero-length column the ratio is `0.0 / 0.0` (NaN),
///   so this assertion would fail as well; confirm whether empty inputs can
///   reach this function.
fn _norm_apply_nan_mask(
    weights_dfw: &DataFrame,
    nan_mask_dfw: &DataFrame,
) -> Result<DataFrame, PolarsError> {
    let mut norm_dfw = DataFrame::new(vec![])?;

    // `get_column_names` already returns a fresh Vec, so no clone is needed.
    for col in weights_dfw.get_column_names() {
        let weight_col = weights_dfw.column(col)?;
        let nan_mask_col = nan_mask_dfw.column(col)?;

        // Inverted mask as floats, so that multiplying by it keeps the rows
        // whose mask entry is false (presumably cast to 1.0) and zeroes the rest.
        let float_col = (!(nan_mask_col.bool()?))
            .into_column()
            .cast(&DataType::Float64)?;
        let float_series = (weight_col * &float_col)?;

        let sum = float_series.f64()?.sum().unwrap();
        if sum < TOLERANCE {
            // Sum of the inverted mask; `float_col` is already Float64, so it
            // is reused instead of being cast a second time.
            let nan_mask_sum = float_col.f64()?.sum().unwrap();
            let weight_sum = weight_col.f64()?.sum().unwrap();
            let n_rows = weight_col.len() as f64;
            // Relative distance between the inverted-mask sum and the column
            // length; close to zero means the mask removed (almost) nothing.
            let ratio = (nan_mask_sum - n_rows).abs() / n_rows;
            // Report the ratio itself, matching the "w ratio" label in the message.
            assert!(
                ratio < TOLERANCE,
                "The sum of the updated series is zero: {:?} w ratio: {:?} and the weights summing to: {:?}",
                sum, ratio, weight_sum
            );
        }

        norm_dfw.with_column(float_series)?;
    }

    Ok(norm_dfw)
}
fn normalize_weights_with_nan_mask( fn normalize_weights_with_nan_mask(
weights_dfw: DataFrame, weights_dfw: DataFrame,
@ -23,35 +52,14 @@ fn normalize_weights_with_nan_mask(
"The columns in weights_dfw and nan_mask_dfw do not match." "The columns in weights_dfw and nan_mask_dfw do not match."
); );
let col_names = weights_dfw.get_column_names(); let mut norm_dfw = _norm_apply_nan_mask(&weights_dfw, &nan_mask_dfw)?;
// clone weights_dfw
let mut norm_dfw = DataFrame::new(vec![])?;
// Iterate over each column in the DataFrame, and apply the nan mask
for col in col_names.clone() {
let weight_series = weights_dfw.column(col)?;
let nan_mask_series = nan_mask_dfw.column(col)?;
// Replace values in `weight_series` with nulls where `nan_mask_series` is true
let updated_series = weight_series.zip_with(
nan_mask_series.bool()?,
&Series::full_null(
col.to_string().into(),
weight_series.len(),
weight_series.dtype(),
)
.into_column(),
)?;
// Update the column in the new DataFrame
norm_dfw.with_column(updated_series)?;
}
let sums_of_rows = norm_dfw let sums_of_rows = norm_dfw
.sum_horizontal(polars::frame::NullStrategy::Ignore)? .sum_horizontal(polars::frame::NullStrategy::Ignore)?
.unwrap() .unwrap()
.into_column(); .into_column();
for col in col_names {
for col in weights_dfw.get_column_names() {
let weight_series = norm_dfw.column(col)?; let weight_series = norm_dfw.column(col)?;
let divided_series = weight_series.divide(&sums_of_rows)?; let divided_series = weight_series.divide(&sums_of_rows)?;
let normalized_series = divided_series.as_series().unwrap().clone(); let normalized_series = divided_series.as_series().unwrap().clone();