Working, except for weight normalization.

This commit is contained in:
Palash Tyagi 2025-04-08 22:37:20 +01:00
parent 9fad1fbae5
commit e37b4c19cb

View File

@ -9,48 +9,193 @@ use polars::prelude::*;
use polars::series::Series; // Series struct
use std::collections::HashMap;
/// Applies a NaN mask to a weights DataFrame, then normalizes each row so
/// that the absolute values of the surviving weights sum to 1.
///
/// Pass 1 replaces every weight whose mask entry is `true` with null.
/// Pass 2 divides each remaining weight by the row-wise sum of absolute
/// weights; entries in rows whose sum is zero are left null.
///
/// # Errors
/// Returns a `PolarsError` if a column lookup, boolean/float view, or
/// column replacement fails.
///
/// # Panics
/// Asserts that the frame is non-empty after masking and after
/// normalization (these guard broken upstream pipelines).
fn normalize_weights_with_nan_mask(
    mut weights_dfw: DataFrame,
    nan_mask_dfw: DataFrame,
) -> Result<DataFrame, PolarsError> {
    // Snapshot the column names once. Neither pass below adds or removes
    // columns, so the same list is valid for both loops (the original code
    // recomputed it a second time for no benefit).
    let column_names: Vec<String> = weights_dfw
        .get_column_names()
        .iter()
        .map(|s| s.to_string())
        .collect();
    // Pass 1: null-out weights wherever the NaN mask is set.
    for col in &column_names {
        let series = weights_dfw.column(col)?.clone();
        let mask = nan_mask_dfw.column(col)?.bool()?;
        let masked = series.zip_with(
            mask,
            &Series::full_null(col.to_string().into(), series.len(), series.dtype()).into_column(),
        )?;
        let masked = masked.as_series().unwrap();
        weights_dfw.replace(col, masked.clone())?;
    }
    assert!(
        weights_dfw.height() > 0,
        "weights_dfw is empty after masking."
    );
    // Row-wise sum of absolute weights; nulls are skipped by filter_map, so
    // a fully-null row yields a sum of 0.0.
    let row_sums: Vec<f64> = (0..weights_dfw.height())
        .map(|i| {
            weights_dfw
                .get_columns()
                .iter()
                .filter_map(|s| s.f64().unwrap().get(i))
                .map(f64::abs)
                .sum()
        })
        .collect();
    assert!(!row_sums.is_empty(), "row_sums is empty after summation.");
    // Pass 2: divide every weight by its row sum; keep nulls (and rows whose
    // sum is zero) as null rather than dividing by zero.
    for col in &column_names {
        let normalized: Vec<Option<f64>> = weights_dfw
            .column(col)?
            .f64()?
            .into_iter()
            .zip(&row_sums)
            .map(|(opt_val, sum)| {
                opt_val.and_then(|v| if *sum != 0.0 { Some(v / sum) } else { None })
            })
            .collect();
        weights_dfw.replace(col, Series::new(col.to_string().into(), normalized))?;
    }
    assert!(
        weights_dfw.height() > 0,
        "weights_dfw is empty after normalization."
    );
    assert!(
        !weights_dfw.get_column_names().is_empty(),
        "weights_dfw has no columns after normalization."
    );
    Ok(weights_dfw)
}
/// Builds a wide DataFrame containing one column per aggregation target.
///
/// Targets found in `dfw` are copied across verbatim; targets that are
/// missing become all-null `Float64` columns matching `dfw`'s height.
fn _form_agg_data_dfw(dfw: &DataFrame, agg_targs: &Vec<String>) -> Result<DataFrame, PolarsError> {
    let mut data_dfw = DataFrame::new(vec![])?;
    for target in agg_targs {
        if let Ok(existing) = dfw.column(target) {
            // Present in the source frame: carry the column over unchanged.
            data_dfw.with_column(existing.clone())?;
        } else {
            // Absent: substitute a fully-null float column as a placeholder.
            let placeholder =
                Series::full_null(target.into(), dfw.height(), &DataType::Float64);
            data_dfw.with_column(placeholder)?;
        }
    }
    Ok(data_dfw)
}
/// Derives a boolean DataFrame flagging the null cells of `data_dfw`.
///
/// The output has the same shape and column names as the input; `true`
/// marks a null entry in the corresponding cell.
fn _form_agg_nan_mask_dfw(data_dfw: &DataFrame) -> Result<DataFrame, PolarsError> {
    let mut nan_mask_dfw = DataFrame::new(vec![])?;
    for col in data_dfw.get_columns() {
        nan_mask_dfw.with_column(col.is_null().into_series())?;
    }
    Ok(nan_mask_dfw)
}
fn _form_agg_nan_mask_series(nan_mask_dfw: &DataFrame) -> Result<Series, PolarsError> {
// perform a row-wise OR
let columns = nan_mask_dfw.get_columns();
let mut combined = columns[0].bool()?.clone();
for column in columns.iter().skip(1) {
combined = &combined | column.bool()?;
}
Ok(combined.into_series())
}
/// Constructs a constant-weight DataFrame matching the height of `data_dfw`.
///
/// Each entry of `agg_weights_map` maps a ticker to a `[weight, sign]` pair;
/// the resulting column is filled with `weight * sign` on every row.
fn _form_agg_weights_dfw(
    agg_weights_map: &HashMap<String, Vec<f64>>,
    data_dfw: DataFrame,
) -> Result<DataFrame, PolarsError> {
    let height = data_dfw.height();
    let mut weights_dfw = DataFrame::new(vec![])?;
    for (ticker, weight_sign) in agg_weights_map {
        let signed_weight = weight_sign[0] * weight_sign[1];
        let column = Series::new(ticker.into(), vec![signed_weight; height]);
        weights_dfw.with_column(column)?;
    }
    Ok(weights_dfw)
}
/// Verifies that every weights column also exists in the data frame.
///
/// # Errors
/// Returns a `PolarsError::ComputeError` naming the first weights column
/// that is missing from `data_dfw`.
fn _assert_weights_and_data_match(
    data_dfw: &DataFrame,
    weights_dfw: &DataFrame,
) -> Result<(), PolarsError> {
    use std::collections::HashSet;
    // Build the membership set once so each weights column is an O(1) probe
    // instead of a linear scan of data_dfw's column list per iteration.
    let data_cols: HashSet<_> = data_dfw.get_column_names().into_iter().collect();
    for col in weights_dfw.get_column_names() {
        if !data_cols.contains(&col) {
            return Err(PolarsError::ComputeError(
                format!(
                    "weights_dfw and data_dfw do not have the same set of columns: {}",
                    col
                )
                .into(),
            ));
        }
    }
    Ok(())
}
// NOTE(review): this block appears to be a mangled diff rendering — lines from
// the pre-change version (return type `Result<DataFrame, ...>`, the inline
// weights/data placeholder loops, `Ok(new_dfw.clone())`) are interleaved with
// the post-change version (return type `Result<Column, ...>`, the `_form_agg_*`
// helper pipeline, `Ok(sum_series)`). As written it is not syntactically valid
// Rust (two return types, two returns, unbalanced braces); the two revisions
// must be untangled against version control before this can compile.
fn perform_single_group_agg(
dfw: &DataFrame,
new_dfw: &DataFrame,
agg_on: &String,
agg_targs: &Vec<String>,
agg_weights_map: &HashMap<String, Vec<f64>>,
normalize_weights: bool,
complete: bool,
) -> Result<DataFrame, PolarsError> {
// Replace this with the actual implementation
// get all agg_targs as columns
// NOTE(review): second, conflicting return type from the newer diff side.
) -> Result<Column, PolarsError> {
// --- newer diff side: helper-based pipeline ---
let data_dfw = _form_agg_data_dfw(dfw, agg_targs)?;
let nan_mask_dfw = _form_agg_nan_mask_dfw(&data_dfw)?;
let nan_mask_series = _form_agg_nan_mask_series(&nan_mask_dfw)?;
let weights_dfw = _form_agg_weights_dfw(agg_weights_map, data_dfw.clone())?;
// Optionally re-normalize the weights against the NaN mask.
let weights_dfw = match normalize_weights {
true => normalize_weights_with_nan_mask(weights_dfw, nan_mask_dfw)?,
false => weights_dfw,
};
// --- older diff side: inline single-row weights placeholder (shadows the
// helper-built weights_dfw above; removed in the newer revision) ---
let mut weights_dfw = DataFrame::new(vec![])?; // Placeholder for weights DataFrame
for (agg_targ, weight_signs) in agg_weights_map.iter() {
let wgt = weight_signs[0] * weight_signs[1];
let wgt_series = Series::new(agg_targ.into(), vec![wgt]);
weights_dfw.with_column(wgt_series)?;
println!("weights_dfw: {:?}", weights_dfw);
// assert weights and data have the same set of columns
_assert_weights_and_data_match(&data_dfw, &weights_dfw)?;
// let output_dfw = weights_dfw * data_dfw;
// Element-wise multiply each data column by its matching weights column.
let mut output_columns = Vec::with_capacity(data_dfw.get_column_names().len());
for col_name in data_dfw.get_column_names() {
let data_series = data_dfw.column(col_name)?;
let weight_series = weights_dfw.column(col_name)?;
let mut multiplied = (data_series * weight_series)?;
multiplied.rename(col_name.to_string().into());
output_columns.push(multiplied);
}
// --- older diff side: inline data-frame assembly (superseded by
// _form_agg_data_dfw in the newer revision) ---
let mut data_dfw = DataFrame::new(vec![])?; // Placeholder for target DataFrame
for agg_targ in agg_targs {
if !dfw.get_column_names().contains(&&PlSmallStr::from_string(agg_targ.to_string())) {
continue;
}
let agg_targ_series = dfw.column(agg_targ)?.clone();
data_dfw.with_column(agg_targ_series)?;
// Row-wise sum of the weighted columns; nulls beyond the first column are
// treated as zero before adding.
let mut sum_series = output_columns[0].clone();
for i in 1..output_columns.len() {
let filled_column = output_columns[i].fill_null(FillNullStrategy::Zero)?;
sum_series = (&sum_series + &filled_column)?;
}
// nan_mask = [iter over data_dfw.columns() applying is_nan()] OR [iter over data_dfw.rows() applying is_nan()]
let mut nan_mask = DataFrame::new(vec![])?; // Placeholder for NaN mask DataFrame
for col in data_dfw.get_column_names() {
let col_series = data_dfw.column(col)?;
let nan_mask_series = col_series.is_nan()?
.cast(&DataType::Boolean)?
.into_series();
nan_mask.with_column(nan_mask_series)?;
// In "complete" mode, null-out result rows where any input was NaN.
if complete {
sum_series = sum_series.zip_with(
nan_mask_series.bool()?,
&Series::full_null(agg_on.clone().into(), sum_series.len(), sum_series.dtype())
.into_column(),
)?;
}
sum_series.rename(agg_on.clone().into());
// new_dfw.with_column(sum_series)?;
// NOTE(review): two conflicting returns — older side returns the DataFrame,
// newer side returns the summed Column.
Ok(new_dfw.clone())
Ok(sum_series)
}
fn perform_multiplication(
@ -60,21 +205,38 @@ fn perform_multiplication(
complete: bool,
normalize_weights: bool,
) -> Result<DataFrame, PolarsError> {
let real_date = dfw.column("real_date")?.clone();
let mut new_dfw = DataFrame::new(vec![real_date])?;
// let real_date = dfw.column("real_date".into())?.clone();
// let mut new_dfw = DataFrame::new(vec![real_date])?;
let mut new_dfw = DataFrame::new(vec![])?;
assert!(!mult_targets.is_empty(), "agg_targs is empty");
for (agg_on, agg_targs) in mult_targets.iter() {
// perform_single_group_agg
perform_single_group_agg(
let cols_len = new_dfw.get_column_names().len();
let new_col = perform_single_group_agg(
dfw,
&new_dfw,
agg_on,
agg_targs,
&weights_map[agg_on],
normalize_weights,
complete,
)?;
// assert that the number of columns has grown
assert!(
new_col.len() != 0,
"The new DataFrame is empty after aggregation."
);
new_dfw.with_column(new_col)?;
// if the height of new_dfw is 0,
let new_cols_len = new_dfw.get_column_names().len();
assert!(
new_cols_len > cols_len,
"The number of columns did not grow after aggregation."
);
}
// get the real_date column from dfw
let real_date = dfw.column("real_date".into())?.clone();
new_dfw.with_column(real_date.clone())?;
// Placeholder logic to return a valid DataFrame for now
// Replace this with the actual implementation
@ -125,17 +287,20 @@ fn get_mul_targets(
let mut targets = Vec::new();
for agg_t in &agg_targ {
let ticker = match _agg_xcats_for_cid {
true => format!("{}_{}", agg_t, agg_o),
false => format!("{}_{}", agg_o, agg_t),
true => format!("{}_{}", agg_o, agg_t),
false => format!("{}_{}", agg_t, agg_o),
};
if found_tickers.contains(&ticker) {
targets.push(ticker);
}
targets.push(ticker);
}
if !targets.is_empty() {
mul_targets.insert(agg_o.clone(), targets);
}
}
// check if mul_targets is empty
assert!(
!mul_targets.is_empty(),
"The mul_targets is empty. Please check the input DataFrame."
);
Ok(mul_targets)
}
@ -150,7 +315,7 @@ fn form_weights_and_signs_map(
let (agg_on, agg_targ) = get_agg_on_agg_targs(cids.clone(), xcats.clone());
// if weights are None, create a vector of 1s of the same length as agg_targ
let weights = weights.unwrap_or(vec![1.0; agg_targ.len()]);
let weights = weights.unwrap_or(vec![1.0 / agg_targ.len() as f64; agg_targ.len()]);
let signs = signs.unwrap_or(vec![1.0; agg_targ.len()]);
// check that the lengths of weights and signs match the length of agg_targ
@ -167,8 +332,8 @@ fn form_weights_and_signs_map(
let mut agg_t_map = HashMap::new();
for (i, agg_t) in agg_targ.iter().enumerate() {
let ticker = match _agg_xcats_for_cid {
true => format!("{}_{}", agg_t, agg_o),
false => format!("{}_{}", agg_o, agg_t),
true => format!("{}_{}", agg_o, agg_t),
false => format!("{}_{}", agg_t, agg_o),
};
let weight_signs = vec![weights[i], signs[i]];
agg_t_map.insert(ticker, weight_signs);
@ -200,6 +365,29 @@ fn check_weights_signs_lengths<T>(
}
Ok(())
}
/// Renames every non-`real_date` column of `dfw` to its composite ticker.
///
/// When xcats are aggregated per cid (`agg_xcats_for_cid(...)` is true) the
/// new name is `"<col>_<new_xcat>"`; otherwise it is `"<new_cid>_<col>"`.
///
/// # Errors
/// Propagates any error raised by `DataFrame::rename`.
fn rename_result_dfw_cols(
    xcats: Vec<String>,
    cids: Vec<String>,
    new_xcat: String,
    new_cid: String,
    dfw: &mut DataFrame,
) -> Result<(), Box<dyn std::error::Error>> {
    // Collect names up front: renaming while iterating over the borrowed
    // name list would conflict with the mutable borrow of `dfw`.
    let dfw_cols: Vec<String> = dfw
        .get_column_names()
        .iter()
        .filter(|&s| *s != "real_date") // 'real_date' keeps its name
        .map(|s| s.to_string())
        .collect();
    let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone());
    // Idiom fix: the original wrapped the whole `for` loop in `Ok(...)`
    // (`Ok(for … {})`), which compiles only because `for` evaluates to `()`.
    // Loop first, then return `Ok(())` explicitly.
    for col in dfw_cols {
        let new_name = match _agg_xcats_for_cid {
            true => format!("{}_{}", col, new_xcat),
            false => format!("{}_{}", new_cid, col),
        };
        dfw.rename(&col, new_name.into())?;
    }
    Ok(())
}
/// Flags if the xcats are aggregated for a given cid.
/// If true, the xcats are aggregated for each cid, creating a new xcat.
@ -256,7 +444,10 @@ pub fn linear_composite(
)
.unwrap();
let mut dfw = pivot_dataframe_by_ticker(rdf, Some("value".to_string())).unwrap();
let new_xcat = new_xcat.unwrap_or_else(|| "COMPOSITE".to_string());
let new_cid = new_cid.unwrap_or_else(|| "GLB".to_string());
let dfw = pivot_dataframe_by_ticker(rdf.clone(), Some("value".to_string())).unwrap();
let mul_targets = get_mul_targets(cids.clone(), xcats.clone(), &dfw)?;
let weights_map = form_weights_and_signs_map(cids.clone(), xcats.clone(), weights, signs)?;
@ -267,6 +458,19 @@ pub fn linear_composite(
for (agg_on, agg_t_map) in weights_map.iter() {
println!("agg_on: {}, agg_t_map: {:?}", agg_on, agg_t_map);
}
// assert!(0==1, "Debugging weights_map: {:?}", weights_map);
// Perform the multiplication
let complete = complete_xcats || complete_cids;
let mut dfw = perform_multiplication(
&dfw,
&mul_targets,
&weights_map,
complete,
normalize_weights,
)?;
rename_result_dfw_cols(xcats, cids, new_xcat, new_cid, &mut dfw)?;
let dfw = pivot_wide_dataframe_to_qdf(dfw, Some("value".to_string()))?;
Ok(dfw)
}