Mirror of https://github.com/Magnus167/msyrs.git (synced 2025-08-20 13:00:01 +00:00)

Commit e780f34188: "wip: normalization not working"
Parent commit: 0b97e2d0be
@@ -10,72 +10,55 @@ use polars::series::Series; // Series struct
 use std::collections::HashMap;
 
 fn normalize_weights_with_nan_mask(
-    mut weights_dfw: DataFrame,
+    weights_dfw: DataFrame,
     nan_mask_dfw: DataFrame,
 ) -> Result<DataFrame, PolarsError> {
-    let column_names: Vec<String> = weights_dfw
-        .get_column_names()
-        .iter()
-        .map(|s| s.to_string())
-        .collect();
-    for col in column_names {
-        let series = weights_dfw.column(&col)?.clone();
-        let mask = nan_mask_dfw.column(&col)?.bool()?;
-        let masked = series.zip_with(
-            mask,
-            &Series::full_null(col.to_string().into(), series.len(), series.dtype()).into_column(),
-        )?;
-        let masked = masked.as_series().unwrap();
-        weights_dfw.replace(&col, masked.clone())?;
-    }
-    // get the length of weights_dfw
-    assert!(
-        weights_dfw.height() > 0,
-        "weights_dfw is empty after masking."
-    );
-    let row_sums: Vec<f64> = (0..weights_dfw.height())
-        .map(|i| {
-            weights_dfw
-                .get_columns()
-                .iter()
-                .filter_map(|s| s.f64().unwrap().get(i))
-                .map(f64::abs)
-                .sum()
-        })
-        .collect();
-
-    // get the height of row_sums
-    assert!(row_sums.len() > 0, "row_sums is empty after summation.");
-
-    let column_names: Vec<String> = weights_dfw
-        .get_column_names()
-        .iter()
-        .map(|s| s.to_string())
-        .collect();
-    for col in column_names {
-        let normalized: Vec<Option<f64>> = weights_dfw
-            .column(&col)?
-            .f64()?
-            .into_iter()
-            .zip(&row_sums)
-            .map(|(opt_val, sum)| {
-                // If non-null and row sum is nonzero, normalize; else leave as null.
-                opt_val.and_then(|v| if *sum != 0.0 { Some(v / sum) } else { None })
-            })
-            .collect();
-        weights_dfw.replace(&col, Series::new(col.to_string().into(), normalized))?;
-    }
-    // Check if the DataFrame is empty after normalization
-    assert!(
-        weights_dfw.height() > 0,
-        "weights_dfw is empty after normalization."
-    );
-    assert!(
-        weights_dfw.get_column_names().len() > 0,
-        "weights_dfw has no columns after normalization."
-    );
-
-    Ok(weights_dfw)
+    // Ensure dimensions and columns match
+    assert_eq!(weights_dfw.shape(), nan_mask_dfw.shape());
+    assert!(
+        weights_dfw
+            .get_column_names()
+            .iter()
+            .all(|col| nan_mask_dfw.get_column_names().contains(col)),
+        "The columns in weights_dfw and nan_mask_dfw do not match."
+    );
+
+    let col_names = weights_dfw.get_column_names();
+
+    // clone weights_dfw
+    let mut norm_dfw = DataFrame::new(vec![])?;
+    // Iterate over each column in the DataFrame, and apply the nan mask
+    for col in col_names.clone() {
+        let weight_series = weights_dfw.column(col)?;
+        let nan_mask_series = nan_mask_dfw.column(col)?;
+
+        // Replace values in `weight_series` with nulls where `nan_mask_series` is true
+        let updated_series = weight_series.zip_with(
+            nan_mask_series.bool()?,
+            &Series::full_null(
+                col.to_string().into(),
+                weight_series.len(),
+                weight_series.dtype(),
+            )
+            .into_column(),
+        )?;
+
+        // Update the column in the new DataFrame
+        norm_dfw.with_column(updated_series)?;
+    }
+
+    let sums_of_rows = norm_dfw
+        .sum_horizontal(polars::frame::NullStrategy::Ignore)?
+        .unwrap()
+        .into_column();
+
+    for col in col_names {
+        let weight_series = norm_dfw.column(col)?;
+        let divided_series = weight_series.divide(&sums_of_rows)?;
+        let normalized_series = divided_series.as_series().unwrap().clone();
+        norm_dfw.replace(col, normalized_series)?;
+    }
+
+    Ok(norm_dfw)
 }
 
 fn _form_agg_data_dfw(dfw: &DataFrame, agg_targs: &Vec<String>) -> Result<DataFrame, PolarsError> {
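For orientation: the replacement body masks each weights column with the matching boolean column from nan_mask_dfw, horizontally sums the surviving values per row, and divides every column by that row sum. Note that the removed version normalised by the sum of absolute values per row, while the new one divides by the plain horizontal sum. Below is a minimal, polars-free sketch of that mask-then-normalise intent (illustrative only, not code from this repository; the helper name, the plain-Vec types, and the absolute-value row sum are my own choices):

// Illustrative sketch (not from the repository): the same mask-then-normalize
// idea on plain Vec<Option<f64>> rows, without polars.
fn normalize_rows(
    weights: Vec<Vec<Option<f64>>>, // one inner Vec per row, one entry per column
    nan_mask: Vec<Vec<bool>>,       // true => treat the corresponding weight as missing
) -> Vec<Vec<Option<f64>>> {
    weights
        .into_iter()
        .zip(nan_mask)
        .map(|(row, mask_row)| {
            // 1. Apply the mask: masked entries become None (nulls in the DataFrame version).
            let masked: Vec<Option<f64>> = row
                .into_iter()
                .zip(mask_row)
                .map(|(w, is_masked)| if is_masked { None } else { w })
                .collect();
            // 2. Row sum over the surviving values (absolute values, as in the removed version).
            let row_sum: f64 = masked.iter().flatten().map(|v| v.abs()).sum();
            // 3. Divide each surviving weight by the row sum; keep nulls as nulls.
            masked
                .into_iter()
                .map(|w| w.and_then(|v| if row_sum != 0.0 { Some(v / row_sum) } else { None }))
                .collect()
        })
        .collect()
}

fn main() {
    let weights = vec![vec![Some(2.0), Some(1.0), Some(1.0)]];
    let mask = vec![vec![false, true, false]];
    // The masked column drops out, so the row normalizes to [2/3, None, 1/3].
    println!("{:?}", normalize_rows(weights, mask));
}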
@@ -270,17 +253,10 @@ fn get_agg_on_agg_targs(cids: Vec<String>, xcats: Vec<String>) -> (Vec<String>,
 fn get_mul_targets(
     cids: Vec<String>,
     xcats: Vec<String>,
-    dfw: &DataFrame,
 ) -> Result<HashMap<String, Vec<String>>, Box<dyn std::error::Error>> {
     let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone());
     let mut mul_targets = HashMap::new();
 
-    let found_tickers = dfw
-        .get_column_names()
-        .iter()
-        .map(|name| name.to_string())
-        .collect::<Vec<String>>();
-
     let (agg_on, agg_targ) = get_agg_on_agg_targs(cids.clone(), xcats.clone());
 
     for agg_o in agg_on {
@@ -392,6 +368,7 @@ fn rename_result_dfw_cols(
 /// Flags if the xcats are aggregated for a given cid.
 /// If true, the xcats are aggregated for each cid, creating a new xcat.
 /// If false, the cids are aggregated for each xcat, creating a new cid.
+#[allow(unused_variables)]
 fn agg_xcats_for_cid(cids: Vec<String>, xcats: Vec<String>) -> bool {
     // if there is more than 1 xcat, return xcats.len() > 1
     xcats.len() > 1
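A quick illustration of what the flag means, restating the one-liner above with hypothetical cid and xcat values (not repository tests):

#[allow(unused_variables)]
fn agg_xcats_for_cid(cids: Vec<String>, xcats: Vec<String>) -> bool {
    xcats.len() > 1
}

fn main() {
    // Several cids, one xcat -> false: the cids are aggregated for that xcat, creating a new cid.
    assert!(!agg_xcats_for_cid(
        vec!["USD".into(), "EUR".into()],
        vec!["XR".into()],
    ));
    // Several xcats -> true: the xcats are aggregated for each cid, creating a new xcat.
    assert!(agg_xcats_for_cid(
        vec!["USD".into()],
        vec!["XR".into(), "CRY".into()],
    ));
}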
@@ -449,7 +426,7 @@ pub fn linear_composite(
 
     let dfw = pivot_dataframe_by_ticker(rdf.clone(), Some("value".to_string())).unwrap();
 
-    let mul_targets = get_mul_targets(cids.clone(), xcats.clone(), &dfw)?;
+    let mul_targets = get_mul_targets(cids.clone(), xcats.clone())?;
     let weights_map = form_weights_and_signs_map(cids.clone(), xcats.clone(), weights, signs)?;
 
     for (ticker, targets) in mul_targets.iter() {