Mirror of https://github.com/Magnus167/msyrs.git (synced 2025-08-20 07:20:01 +00:00)
wip: normalization not working
This commit is contained in:
parent 0b97e2d0be
commit e780f34188
@@ -10,72 +10,55 @@ use polars::series::Series; // Series struct
use std::collections::HashMap;

fn normalize_weights_with_nan_mask(
    mut weights_dfw: DataFrame,
    weights_dfw: DataFrame,
    nan_mask_dfw: DataFrame,
) -> Result<DataFrame, PolarsError> {
    let column_names: Vec<String> = weights_dfw
        .get_column_names()
        .iter()
        .map(|s| s.to_string())
        .collect();
    for col in column_names {
        let series = weights_dfw.column(&col)?.clone();
        let mask = nan_mask_dfw.column(&col)?.bool()?;
        let masked = series.zip_with(
            mask,
            &Series::full_null(col.to_string().into(), series.len(), series.dtype()).into_column(),
        )?;
        let masked = masked.as_series().unwrap();
        weights_dfw.replace(&col, masked.clone())?;
    }
    // get the length of weights_dfw
    // Ensure dimensions and columns match
    assert_eq!(weights_dfw.shape(), nan_mask_dfw.shape());
    assert!(
        weights_dfw.height() > 0,
        "weights_dfw is empty after masking."
    );
    let row_sums: Vec<f64> = (0..weights_dfw.height())
        .map(|i| {
            weights_dfw
                .get_columns()
                .iter()
                .filter_map(|s| s.f64().unwrap().get(i))
                .map(f64::abs)
                .sum()
        })
        .collect();

    // get the height of row_sums
    assert!(row_sums.len() > 0, "row_sums is empty after summation.");

    let column_names: Vec<String> = weights_dfw
        .get_column_names()
        .iter()
        .map(|s| s.to_string())
        .collect();
    for col in column_names {
        let normalized: Vec<Option<f64>> = weights_dfw
            .column(&col)?
            .f64()?
            .into_iter()
            .zip(&row_sums)
            .map(|(opt_val, sum)| {
                // If non-null and row sum is nonzero, normalize; else leave as null.
                opt_val.and_then(|v| if *sum != 0.0 { Some(v / sum) } else { None })
            })
            .collect();
        weights_dfw.replace(&col, Series::new(col.to_string().into(), normalized))?;
    }
    // Check if the DataFrame is empty after normalization
    assert!(
        weights_dfw.height() > 0,
        "weights_dfw is empty after normalization."
    );
    assert!(
        weights_dfw.get_column_names().len() > 0,
        "weights_dfw has no columns after normalization."
    );
    assert!(
        weights_dfw
            .get_column_names()
            .iter()
            .all(|col| nan_mask_dfw.get_column_names().contains(col)),
        "The columns in weights_dfw and nan_mask_dfw do not match."
    );

    Ok(weights_dfw)
    let col_names = weights_dfw.get_column_names();

    // clone weights_dfw
    let mut norm_dfw = DataFrame::new(vec![])?;
    // Iterate over each column in the DataFrame, and apply the nan mask
    for col in col_names.clone() {
        let weight_series = weights_dfw.column(col)?;
        let nan_mask_series = nan_mask_dfw.column(col)?;

        // Replace values in `weight_series` with nulls where `nan_mask_series` is true
        let updated_series = weight_series.zip_with(
            nan_mask_series.bool()?,
            &Series::full_null(
                col.to_string().into(),
                weight_series.len(),
                weight_series.dtype(),
            )
            .into_column(),
        )?;

        // Update the column in the new DataFrame
        norm_dfw.with_column(updated_series)?;
    }

    let sums_of_rows = norm_dfw
        .sum_horizontal(polars::frame::NullStrategy::Ignore)?
        .unwrap()
        .into_column();
    for col in col_names {
        let weight_series = norm_dfw.column(col)?;
        let divided_series = weight_series.divide(&sums_of_rows)?;
        let normalized_series = divided_series.as_series().unwrap().clone();
        norm_dfw.replace(col, normalized_series)?;
    }

    Ok(norm_dfw)
}

fn _form_agg_data_dfw(dfw: &DataFrame, agg_targs: &Vec<String>) -> Result<DataFrame, PolarsError> {
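The masking and normalization above boils down to: null out every weight flagged by the NaN mask, then divide each remaining weight by the sum of absolute weights in its row, leaving rows that sum to zero as nulls. A minimal plain-Rust sketch of that behaviour, with illustrative names only, not part of this commit:

// Minimal sketch (not part of this commit): mask out weights, then divide
// each surviving weight by the row's sum of absolute values.
fn normalize_row(weights: &[f64], nan_mask: &[bool]) -> Vec<Option<f64>> {
    // A `true` in the mask means "treat this weight as missing".
    let masked: Vec<Option<f64>> = weights
        .iter()
        .zip(nan_mask)
        .map(|(w, &is_nan)| if is_nan { None } else { Some(*w) })
        .collect();
    // Sum of absolute values of the surviving weights in this row.
    let row_sum: f64 = masked.iter().flatten().map(|w| w.abs()).sum();
    // Divide by the row sum; if the row sums to zero, keep entries null.
    masked
        .into_iter()
        .map(|opt| opt.and_then(|w| if row_sum != 0.0 { Some(w / row_sum) } else { None }))
        .collect()
}

fn main() {
    // Example: the third weight is masked, so it stays null and is excluded from the sum.
    let row = normalize_row(&[0.5, -1.0, 2.0, 4.0], &[false, false, true, false]);
    assert_eq!(row, vec![Some(0.5 / 5.5), Some(-1.0 / 5.5), None, Some(4.0 / 5.5)]);
}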
@@ -270,17 +253,10 @@ fn get_agg_on_agg_targs(cids: Vec<String>, xcats: Vec<String>) -> (Vec<String>,
fn get_mul_targets(
    cids: Vec<String>,
    xcats: Vec<String>,
    dfw: &DataFrame,
) -> Result<HashMap<String, Vec<String>>, Box<dyn std::error::Error>> {
    let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone());
    let mut mul_targets = HashMap::new();

    let found_tickers = dfw
        .get_column_names()
        .iter()
        .map(|name| name.to_string())
        .collect::<Vec<String>>();

    let (agg_on, agg_targ) = get_agg_on_agg_targs(cids.clone(), xcats.clone());

    for agg_o in agg_on {
@@ -392,6 +368,7 @@ fn rename_result_dfw_cols(
/// Flags if the xcats are aggregated for a given cid.
/// If true, the xcats are aggregated for each cid, creating a new xcat.
/// If false, the cids are aggregated for each xcat, creating a new cid.
#[allow(unused_variables)]
fn agg_xcats_for_cid(cids: Vec<String>, xcats: Vec<String>) -> bool {
    // if there is more than 1 xcat, return xcats.len() > 1
    xcats.len() > 1
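As the doc comment describes, the flag depends only on how many xcats are passed. A hypothetical call to the function above, with made-up cid and xcat values, just to show the convention:

// Illustrative only: one cid with two xcats means the xcats are
// aggregated per cid, so the flag is true.
let aggregate_per_cid = agg_xcats_for_cid(
    vec!["USD".to_string()],
    vec!["XR".to_string(), "CRY".to_string()],
);
assert!(aggregate_per_cid);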
@@ -449,7 +426,7 @@ pub fn linear_composite(

    let dfw = pivot_dataframe_by_ticker(rdf.clone(), Some("value".to_string())).unwrap();

    let mul_targets = get_mul_targets(cids.clone(), xcats.clone(), &dfw)?;
    let mul_targets = get_mul_targets(cids.clone(), xcats.clone())?;
    let weights_map = form_weights_and_signs_map(cids.clone(), xcats.clone(), weights, signs)?;

    for (ticker, targets) in mul_targets.iter() {