mirror of
https://github.com/Magnus167/msyrs.git
synced 2025-08-20 13:10:00 +00:00
Refactor normalization logic to apply NaN mask in a separate function
This commit is contained in:
parent
de62daaf8b
commit
3ca7221965
@ -1,13 +1,42 @@
|
|||||||
use crate::utils::misc::*;
|
|
||||||
use crate::utils::qdf::check_quantamental_dataframe;
|
use crate::utils::qdf::check_quantamental_dataframe;
|
||||||
use crate::utils::qdf::pivots::*;
|
use crate::utils::qdf::pivots::*;
|
||||||
use crate::utils::qdf::reduce_df::*;
|
use crate::utils::qdf::reduce_df::*;
|
||||||
use chrono::NaiveDate;
|
|
||||||
use ndarray::Data;
|
|
||||||
use ndarray::{s, Array, Array1, Zip};
|
|
||||||
use polars::prelude::*;
|
use polars::prelude::*;
|
||||||
use polars::series::Series; // Series struct
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
const TOLERANCE: f64 = 1e-8;
|
||||||
|
|
||||||
|
fn _norm_apply_nan_mask(
|
||||||
|
weights_dfw: &DataFrame,
|
||||||
|
nan_mask_dfw: &DataFrame,
|
||||||
|
) -> Result<DataFrame, PolarsError> {
|
||||||
|
let mut norm_dfw = DataFrame::new(vec![])?;
|
||||||
|
for col in weights_dfw.get_column_names().clone() {
|
||||||
|
let weight_col = weights_dfw.column(col)?;
|
||||||
|
let nan_mask_col = nan_mask_dfw.column(col)?;
|
||||||
|
let nan_mask_series = (!(nan_mask_col.bool()?)).into_column();
|
||||||
|
let float_col = nan_mask_series.cast(&DataType::Float64)?;
|
||||||
|
let float_series = (weight_col * &float_col)?;
|
||||||
|
let sum = float_series.f64()?.sum().unwrap();
|
||||||
|
if sum < TOLERANCE {
|
||||||
|
// get the sum of the nan_mask_series
|
||||||
|
let nan_mask_sum = nan_mask_series
|
||||||
|
.cast(&DataType::Float64)?
|
||||||
|
.f64()?
|
||||||
|
.sum()
|
||||||
|
.unwrap();
|
||||||
|
// if the sum of the nan_mask_series is close to the len of the series, then ok
|
||||||
|
let weight_sum = weight_col.f64()?.sum().unwrap();
|
||||||
|
let ratio = (nan_mask_sum - weight_col.len() as f64).abs() / weight_col.len() as f64;
|
||||||
|
assert!(
|
||||||
|
ratio < TOLERANCE,
|
||||||
|
"The sum of the updated series is zero: {:?} w ratio: {:?} and the weights summing to: {:?}",
|
||||||
|
sum, nan_mask_sum, weight_sum
|
||||||
|
);
|
||||||
|
}
|
||||||
|
norm_dfw.with_column(float_series)?;
|
||||||
|
}
|
||||||
|
Ok(norm_dfw)
|
||||||
|
}
|
||||||
|
|
||||||
fn normalize_weights_with_nan_mask(
|
fn normalize_weights_with_nan_mask(
|
||||||
weights_dfw: DataFrame,
|
weights_dfw: DataFrame,
|
||||||
@ -23,35 +52,14 @@ fn normalize_weights_with_nan_mask(
|
|||||||
"The columns in weights_dfw and nan_mask_dfw do not match."
|
"The columns in weights_dfw and nan_mask_dfw do not match."
|
||||||
);
|
);
|
||||||
|
|
||||||
let col_names = weights_dfw.get_column_names();
|
let mut norm_dfw = _norm_apply_nan_mask(&weights_dfw, &nan_mask_dfw)?;
|
||||||
|
|
||||||
// clone weights_dfw
|
|
||||||
let mut norm_dfw = DataFrame::new(vec![])?;
|
|
||||||
// Iterate over each column in the DataFrame, and apply the nan mask
|
|
||||||
for col in col_names.clone() {
|
|
||||||
let weight_series = weights_dfw.column(col)?;
|
|
||||||
let nan_mask_series = nan_mask_dfw.column(col)?;
|
|
||||||
|
|
||||||
// Replace values in `weight_series` with nulls where `nan_mask_series` is true
|
|
||||||
let updated_series = weight_series.zip_with(
|
|
||||||
nan_mask_series.bool()?,
|
|
||||||
&Series::full_null(
|
|
||||||
col.to_string().into(),
|
|
||||||
weight_series.len(),
|
|
||||||
weight_series.dtype(),
|
|
||||||
)
|
|
||||||
.into_column(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// Update the column in the new DataFrame
|
|
||||||
norm_dfw.with_column(updated_series)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let sums_of_rows = norm_dfw
|
let sums_of_rows = norm_dfw
|
||||||
.sum_horizontal(polars::frame::NullStrategy::Ignore)?
|
.sum_horizontal(polars::frame::NullStrategy::Ignore)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.into_column();
|
.into_column();
|
||||||
for col in col_names {
|
|
||||||
|
for col in weights_dfw.get_column_names() {
|
||||||
let weight_series = norm_dfw.column(col)?;
|
let weight_series = norm_dfw.column(col)?;
|
||||||
let divided_series = weight_series.divide(&sums_of_rows)?;
|
let divided_series = weight_series.divide(&sums_of_rows)?;
|
||||||
let normalized_series = divided_series.as_series().unwrap().clone();
|
let normalized_series = divided_series.as_series().unwrap().clone();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user