use crate::utils::qdf::check_quantamental_dataframe; use crate::utils::qdf::pivots::{pivot_dataframe_by_ticker, pivot_wide_dataframe_to_qdf}; use crate::utils::qdf::reduce_df::reduce_dataframe; use polars::prelude::*; use std::collections::HashMap; const TOLERANCE: f64 = 1e-8; fn _norm_apply_nan_mask( weights_dfw: &DataFrame, nan_mask_dfw: &DataFrame, ) -> Result { let mut norm_dfw = DataFrame::new(vec![])?; for col in weights_dfw.get_column_names().clone() { let weight_col = weights_dfw.column(col)?; let nan_mask_col = nan_mask_dfw.column(col)?; let nan_mask_series = (!(nan_mask_col.bool()?)).into_column(); let float_col = nan_mask_series.cast(&DataType::Float64)?; let float_series = (weight_col * &float_col)?; let sum = float_series.f64()?.sum().unwrap(); if sum < TOLERANCE { // get the sum of the nan_mask_series let nan_mask_sum = nan_mask_series .cast(&DataType::Float64)? .f64()? .sum() .unwrap(); // if the sum of the nan_mask_series is close to the len of the series, then ok let weight_sum = weight_col.f64()?.sum().unwrap(); let ratio = (nan_mask_sum - weight_col.len() as f64).abs() / weight_col.len() as f64; assert!( ratio < TOLERANCE, "The sum of the updated series is zero: {:?} w ratio: {:?} and the weights summing to: {:?}", sum, nan_mask_sum, weight_sum ); } norm_dfw.with_column(float_series)?; } Ok(norm_dfw) } fn normalize_weights_with_nan_mask( weights_dfw: DataFrame, nan_mask_dfw: DataFrame, ) -> Result { // Ensure dimensions and columns match assert_eq!(weights_dfw.shape(), nan_mask_dfw.shape()); assert!( weights_dfw .get_column_names() .iter() .all(|col| nan_mask_dfw.get_column_names().contains(col)), "The columns in weights_dfw and nan_mask_dfw do not match." ); let mut norm_dfw = _norm_apply_nan_mask(&weights_dfw, &nan_mask_dfw)?; let sums_of_rows = norm_dfw .sum_horizontal(polars::frame::NullStrategy::Ignore)? .unwrap() .into_column(); for col in weights_dfw.get_column_names() { let weight_series = norm_dfw.column(col)?; let divided_series = weight_series.divide(&sums_of_rows)?; let normalized_series = divided_series.as_series().unwrap().clone(); norm_dfw.replace(col, normalized_series)?; } Ok(norm_dfw) } fn _form_agg_data_dfw(dfw: &DataFrame, agg_targs: &Vec) -> Result { let mut data_dfw = DataFrame::new(vec![])?; for agg_targ in agg_targs { match dfw.column(agg_targ) { Ok(agg_targ_series) => { // If the column exists, clone it and add it to data_dfw data_dfw.with_column(agg_targ_series.clone())?; } Err(_) => { // If the column does not exist, create a series full of NaNs let nan_series = Series::full_null(agg_targ.into(), dfw.height(), &DataType::Float64); data_dfw.with_column(nan_series)?; } } } Ok(data_dfw) } fn _form_agg_nan_mask_dfw(data_dfw: &DataFrame) -> Result { // Create a NaN mask DataFrame let mut nan_mask_dfw = DataFrame::new(vec![])?; for column in data_dfw.get_columns() { let nan_mask_series = column.is_null().into_series(); nan_mask_dfw.with_column(nan_mask_series)?; } Ok(nan_mask_dfw) } fn _form_agg_nan_mask_series(nan_mask_dfw: &DataFrame) -> Result { // perform a row-wise OR let columns = nan_mask_dfw.get_columns(); let mut combined = columns[0].bool()?.clone(); for column in columns.iter().skip(1) { combined = &combined | column.bool()?; } Ok(combined.into_series()) } /// Form the weights DataFrame fn _form_agg_weights_dfw( agg_weights_map: &HashMap, dfw: &DataFrame, ) -> Result { let mut weights_dfw = DataFrame::new(vec![])?; for (agg_targ, weight_signs) in agg_weights_map.iter() { // let wgt = weight_signs[0] * weight_signs[1]; let wgt_series = match &weight_signs.0 { WeightValue::F64(val) => { let wgt = val * weight_signs.1; Series::new(agg_targ.into(), vec![wgt; dfw.height()]) } WeightValue::Str(vstr) => { // vstr column from data_dfw, else raise wieght specification error if !dfw.get_column_names().contains(&&PlSmallStr::from(vstr)) { return Err(PolarsError::ComputeError( format!( "The column {} does not exist in the DataFrame. {:?}", vstr, agg_weights_map ) .into(), )); } let vstr_series = dfw.column(vstr)?; let multiplied_series = vstr_series * weight_signs.1; let mut multiplied_series = multiplied_series.as_series().cloned().ok_or_else(|| { PolarsError::ComputeError( "Failed to convert multiplied_series to Series".into(), ) })?; multiplied_series.rename(agg_targ.into()); multiplied_series } }; weights_dfw.with_column(wgt_series)?; } Ok(weights_dfw) } fn _assert_weights_and_data_match( data_dfw: &DataFrame, weights_dfw: &DataFrame, ) -> Result<(), PolarsError> { for col in weights_dfw.get_column_names() { if !data_dfw.get_column_names().contains(&col) { return Err(PolarsError::ComputeError( format!( "weights_dfw and data_dfw do not have the same set of columns: {}", col ) .into(), )); } } Ok(()) } fn perform_single_group_agg( dfw: &DataFrame, agg_on: &String, agg_targs: &Vec, agg_weights_map: &HashMap, normalize_weights: bool, complete: bool, ) -> Result { let data_dfw = _form_agg_data_dfw(dfw, agg_targs)?; let nan_mask_dfw = _form_agg_nan_mask_dfw(&data_dfw)?; let nan_mask_series = _form_agg_nan_mask_series(&nan_mask_dfw)?; let weights_dfw = _form_agg_weights_dfw(agg_weights_map, dfw)?; let weights_dfw = match normalize_weights { true => normalize_weights_with_nan_mask(weights_dfw, nan_mask_dfw)?, false => weights_dfw, }; println!("weights_dfw: {:?}", weights_dfw); // assert weights and data have the same set of columns _assert_weights_and_data_match(&data_dfw, &weights_dfw)?; // let output_dfw = weights_dfw * data_dfw; let mut output_columns = Vec::with_capacity(data_dfw.get_column_names().len()); for col_name in data_dfw.get_column_names() { let data_series = data_dfw.column(col_name)?; let weight_series = weights_dfw.column(col_name)?; let mut multiplied = (data_series * weight_series)?; multiplied.rename(col_name.to_string().into()); output_columns.push(multiplied); } let mut sum_series = output_columns[0].clone(); for i in 1..output_columns.len() { let filled_column = output_columns[i].fill_null(FillNullStrategy::Zero)?; sum_series = (&sum_series + &filled_column)?; } if complete { sum_series = sum_series.zip_with( nan_mask_series.bool()?, &Series::full_null(agg_on.clone().into(), sum_series.len(), sum_series.dtype()) .into_column(), )?; } sum_series.rename(agg_on.clone().into()); // new_dfw.with_column(sum_series)?; Ok(sum_series) } fn perform_multiplication( dfw: &DataFrame, mult_targets: &HashMap>, weights_map: &HashMap>, complete: bool, normalize_weights: bool, ) -> Result { // let real_date = dfw.column("real_date".into())?.clone(); // let mut new_dfw = DataFrame::new(vec![real_date])?; let mut new_dfw = DataFrame::new(vec![])?; assert!(!mult_targets.is_empty(), "agg_targs is empty"); for (agg_on, agg_targs) in mult_targets.iter() { // perform_single_group_agg let cols_len = new_dfw.get_column_names().len(); let new_col = perform_single_group_agg( dfw, agg_on, agg_targs, &weights_map[agg_on], normalize_weights, complete, )?; // assert that the number of columns has grown assert!( new_col.len() != 0, "The new DataFrame is empty after aggregation." ); new_dfw.with_column(new_col)?; // if the height of new_dfw is 0, let new_cols_len = new_dfw.get_column_names().len(); assert!( new_cols_len > cols_len, "The number of columns did not grow after aggregation." ); } // get the real_date column from dfw let real_date = dfw.column("real_date".into())?.clone(); new_dfw.with_column(real_date.clone())?; // Placeholder logic to return a valid DataFrame for now // Replace this with the actual implementation Ok(new_dfw) } fn get_agg_on_agg_targs(cids: Vec, xcats: Vec) -> (Vec, Vec) { let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone()); let (agg_on, agg_targ) = if _agg_xcats_for_cid { (cids.clone(), xcats.clone()) } else { (xcats.clone(), cids.clone()) }; // assert that if agg_xcats_for_cid is true, agg_on = cids match _agg_xcats_for_cid { true => { assert_eq!(agg_on, cids); assert_eq!(agg_targ, xcats); } false => { assert_eq!(agg_on, xcats); assert_eq!(agg_targ, cids); } } (agg_on, agg_targ) } /// Get the mapping of aggregation targets for the implied mode of aggregation. /// # Returns /// * `HashMap>` - A mapping of cid/xcat to the list of tickers to be aggregated. fn get_mul_targets( cids: Vec, xcats: Vec, ) -> Result>, Box> { let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone()); let mut mul_targets = HashMap::new(); let (agg_on, agg_targ) = get_agg_on_agg_targs(cids.clone(), xcats.clone()); for agg_o in agg_on { let mut targets = Vec::new(); for agg_t in &agg_targ { let ticker = match _agg_xcats_for_cid { true => format!("{}_{}", agg_o, agg_t), false => format!("{}_{}", agg_t, agg_o), }; targets.push(ticker); } if !targets.is_empty() { mul_targets.insert(agg_o.clone(), targets); } } // check if mul_targets is empty assert!( !mul_targets.is_empty(), "The mul_targets is empty. Please check the input DataFrame." ); Ok(mul_targets) } /// Builds a map of the shape: /// `HashMap>` /// where only one of `weights` or `weight_xcats` can be provided. /// If neither is provided, weights default to 1.0. /// Each tuple is `(WeightValue, f64) = (weight, sign)`. fn form_weights_and_signs_map( cids: Vec, xcats: Vec, weights: Option>, weight_xcat: Option, signs: Option>, ) -> Result>, Box> { // For demonstration, we pretend to load or infer these from helpers: let agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone()); let (agg_on, agg_targ) = get_agg_on_agg_targs(cids.clone(), xcats.clone()); // Determine if each weight option has non-empty values. let weights_provided = weights.as_ref().map_or(false, |v| !v.is_empty()); let weight_xcats_provided = weight_xcat.as_ref().map_or(false, |v| !v.is_empty()); // Enforce that only one of weights or weight_xcats is specified. if weights_provided && weight_xcats_provided { return Err("Only one of `weights` and `weight_xcats` may be specified.".into()); } // 1) Build the "actual_weights" vector as WeightValue. let actual_weights: Vec = if weights_provided { weights.unwrap().into_iter().map(WeightValue::F64).collect() } else if weight_xcats_provided { vec![WeightValue::Str(weight_xcat.unwrap()); agg_targ.len()] } else { // Default to numeric 1.0 if neither is provided vec![WeightValue::F64(1.0); agg_targ.len()] }; // 2) Build the "signs" vector; default to 1.0 if not provided let signs = signs.unwrap_or_else(|| vec![1.0; agg_targ.len()]); // 3) Optional: check lengths & zero values (only numeric weights). check_weights_signs_lengths(&actual_weights, &signs, agg_xcats_for_cid, agg_targ.len())?; // 4) Build the final nested HashMap let mut weights_map: HashMap> = HashMap::new(); for agg_o in agg_on { let mut agg_t_map = HashMap::new(); for (i, agg_t) in agg_targ.iter().enumerate() { // Format the ticker let ticker = if agg_xcats_for_cid { format!("{}_{}", agg_o, agg_t) } else { format!("{}_{}", agg_t, agg_o) }; // Build the tuple (WeightValue, f64) let weight_sign_tuple = match &actual_weights[i] { WeightValue::F64(val) => (WeightValue::F64(*val).clone(), signs[i]), WeightValue::Str(vstr) => { let new_str = format!("{}_{}", agg_t, vstr); (WeightValue::Str(new_str), signs[i]) } }; agg_t_map.insert(ticker, weight_sign_tuple); } weights_map.insert(agg_o.clone(), agg_t_map); } Ok(weights_map) } /// Checks that the given slices have the expected length and that: /// - numeric weights are non-zero, /// - signs are non-zero. fn check_weights_signs_lengths( weights_vec: &[WeightValue], signs_vec: &[f64], agg_xcats_for_cid: bool, agg_targ_len: usize, ) -> Result<(), Box> { // For diagnostics, decide what to call the dimension let agg_targ = if agg_xcats_for_cid { "xcats" } else { "cids" }; // 1) Check numeric weights for zeroes. for (i, weight) in weights_vec.iter().enumerate() { if let WeightValue::F64(val) = weight { if *val == 0.0 { return Err(format!("The weight at index {} is 0.0", i).into()); } } } // 2) Ensure the weights vector is the expected length. if weights_vec.len() != agg_targ_len { return Err(format!( "The length of weights ({}) does not match the length of {} ({})", weights_vec.len(), agg_targ, agg_targ_len ) .into()); } // 3) Check signs for zero. for (i, sign) in signs_vec.iter().enumerate() { if *sign == 0.0 { return Err(format!("The sign at index {} is 0.0", i).into()); } } // 4) Ensure the signs vector is the expected length. if signs_vec.len() != agg_targ_len { return Err(format!( "The length of signs ({}) does not match the length of {} ({})", signs_vec.len(), agg_targ, agg_targ_len ) .into()); } Ok(()) } fn rename_result_dfw_cols( xcats: Vec, cids: Vec, new_xcat: String, new_cid: String, dfw: &mut DataFrame, ) -> Result<(), Box> { let dfw_cols: Vec = dfw .get_column_names() .iter() .filter(|&s| *s != "real_date") // Exclude 'real_date' .map(|s| s.to_string()) .collect(); let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone()); Ok(for col in dfw_cols { let new_name = match _agg_xcats_for_cid { true => format!("{}_{}", col, new_xcat), false => format!("{}_{}", new_cid, col), }; // rename the column dfw.rename(&col, new_name.into())?; }) } /// Flags if the xcats are aggregated for a given cid. /// If true, the xcats are aggregated for each cid, creating a new xcat. /// If false, the cids are aggregated for each xcat, creating a new cid. #[allow(unused_variables)] fn agg_xcats_for_cid(cids: Vec, xcats: Vec) -> bool { // if there is more than 1 xcat, return xcats.len() > 1 xcats.len() > 1 } /// Represents a weight value that can be a string, (float, or integer). #[derive(Debug, Clone, PartialEq)] pub enum WeightValue { Str(String), F64(f64), } impl From for WeightValue { fn from(s: String) -> Self { WeightValue::Str(s) } } impl<'a> From<&'a str> for WeightValue { fn from(s: &'a str) -> Self { WeightValue::Str(s.to_string()) } } impl From for WeightValue { fn from(f: f64) -> Self { WeightValue::F64(f) } } impl From for WeightValue { fn from(i: i32) -> Self { WeightValue::F64(i as f64) } } /// Weighted linear combinations of cross sections or categories /// # Arguments /// * `df` - QDF DataFrame /// * `xcats` - List of category names or a single category name /// * `cids` - List of cross section names or None /// * `weights` - List of weights or a string indicating a weight `xcat` /// * `normalize_weights` - Normalize weights to sum to 1 before applying /// * `signs` - List of signs for each category (+1 or -1) /// * `start` - Start date for the analysis /// * `end` - End date for the analysis /// * `blacklist` - Dictionary of blacklisted categories /// * `complete_xcats` - If True, complete xcats with missing values /// * `complete_cids` - If True, complete cids with missing values /// * `new_xcat` - Name of the new xcat /// * `new_cid` - Name of the new cid /// /// # Returns /// * `DataFrame` - DataFrame with the linear composite pub fn linear_composite( df: &DataFrame, xcats: Vec, cids: Vec, weights: Option>, signs: Option>, weight_xcat: Option, normalize_weights: bool, start: Option, end: Option, blacklist: Option>>, complete_xcats: bool, complete_cids: bool, new_xcat: Option, new_cid: Option, ) -> Result> { // Check if the DataFrame is a Quantamental DataFrame check_quantamental_dataframe(df)?; if agg_xcats_for_cid(cids.clone(), xcats.clone()) { if weight_xcat.is_some() { return Err( format!( "Using xcats as weights is not supported when aggregating cids for a single xcat. {:?} {:?}", cids, xcats ) .into(), ); } } let mut rxcats = xcats.clone(); if weight_xcat.is_some() { rxcats.extend(vec![weight_xcat.clone().unwrap()]); } let rdf = reduce_dataframe( df.clone(), Some(cids.clone()), Some(rxcats.clone()), Some(vec!["value".to_string()]), start.clone(), end.clone(), false, ) .unwrap(); let new_xcat = new_xcat.unwrap_or_else(|| "COMPOSITE".to_string()); let new_cid = new_cid.unwrap_or_else(|| "GLB".to_string()); let dfw = pivot_dataframe_by_ticker(rdf, Some("value".to_string())).unwrap(); let mul_targets = get_mul_targets(cids.clone(), xcats.clone())?; let weights_map = form_weights_and_signs_map(cids.clone(), xcats.clone(), weights, weight_xcat, signs)?; for (ticker, targets) in mul_targets.iter() { println!("ticker: {}, targets: {:?}", ticker, targets); } for (agg_on, agg_t_map) in weights_map.iter() { println!("agg_on: {}, agg_t_map: {:?}", agg_on, agg_t_map); } // assert!(0==1, "Debugging weights_map: {:?}", weights_map); // Perform the multiplication let complete = complete_xcats || complete_cids; let mut dfw = perform_multiplication( &dfw, &mul_targets, &weights_map, complete, normalize_weights, )?; rename_result_dfw_cols(xcats, cids, new_xcat, new_cid, &mut dfw)?; let dfw = pivot_wide_dataframe_to_qdf(dfw, Some("value".to_string()))?; Ok(dfw) }