diff --git a/src/panel/linear_composite.rs b/src/panel/linear_composite.rs index e69de29..39a34e1 100644 --- a/src/panel/linear_composite.rs +++ b/src/panel/linear_composite.rs @@ -0,0 +1,272 @@ +use crate::utils::misc::*; +use crate::utils::qdf::check_quantamental_dataframe; +use crate::utils::qdf::pivots::*; +use crate::utils::qdf::reduce_df::*; +use chrono::NaiveDate; +use ndarray::Data; +use ndarray::{s, Array, Array1, Zip}; +use polars::prelude::*; +use polars::series::Series; // Series struct +use std::collections::HashMap; + +fn perform_single_group_agg( + dfw: &DataFrame, + new_dfw: &DataFrame, + agg_on: &String, + agg_targs: &Vec, + agg_weights_map: &HashMap>, + normalize_weights: bool, + complete: bool, +) -> Result { + // Replace this with the actual implementation + // get all agg_targs as columns + + let mut weights_dfw = DataFrame::new(vec![])?; // Placeholder for weights DataFrame + for (agg_targ, weight_signs) in agg_weights_map.iter() { + let wgt = weight_signs[0] * weight_signs[1]; + let wgt_series = Series::new(agg_targ.into(), vec![wgt]); + weights_dfw.with_column(wgt_series)?; + } + + let mut data_dfw = DataFrame::new(vec![])?; // Placeholder for target DataFrame + for agg_targ in agg_targs { + if !dfw.get_column_names().contains(&&PlSmallStr::from_string(agg_targ.to_string())) { + continue; + } + let agg_targ_series = dfw.column(agg_targ)?.clone(); + data_dfw.with_column(agg_targ_series)?; + } + + + // nan_mask = [iter over data_dfw.columns() applying is_nan()] OR [iter over data_dfw.rows() applying is_nan()] + let mut nan_mask = DataFrame::new(vec![])?; // Placeholder for NaN mask DataFrame + for col in data_dfw.get_column_names() { + let col_series = data_dfw.column(col)?; + let nan_mask_series = col_series.is_nan()? + .cast(&DataType::Boolean)? + .into_series(); + nan_mask.with_column(nan_mask_series)?; + } + + + + Ok(new_dfw.clone()) +} + +fn perform_multiplication( + dfw: &DataFrame, + mult_targets: &HashMap>, + weights_map: &HashMap>>, + complete: bool, + normalize_weights: bool, +) -> Result { + let real_date = dfw.column("real_date")?.clone(); + let mut new_dfw = DataFrame::new(vec![real_date])?; + + for (agg_on, agg_targs) in mult_targets.iter() { + // perform_single_group_agg + perform_single_group_agg( + dfw, + &new_dfw, + agg_on, + agg_targs, + &weights_map[agg_on], + normalize_weights, + complete, + )?; + } + + // Placeholder logic to return a valid DataFrame for now + // Replace this with the actual implementation + Ok(new_dfw) +} + +fn get_agg_on_agg_targs(cids: Vec, xcats: Vec) -> (Vec, Vec) { + let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone()); + let (agg_on, agg_targ) = if _agg_xcats_for_cid { + (cids.clone(), xcats.clone()) + } else { + (xcats.clone(), cids.clone()) + }; + // assert that if agg_xcats_for_cid is true, agg_on = cids + match _agg_xcats_for_cid { + true => { + assert_eq!(agg_on, cids); + assert_eq!(agg_targ, xcats); + } + false => { + assert_eq!(agg_on, xcats); + assert_eq!(agg_targ, cids); + } + } + (agg_on, agg_targ) +} + +/// Get the mapping of aggregation targets for the implied mode of aggregation. +/// # Returns +/// * `HashMap>` - A mapping of cid/xcat to the list of tickers to be aggregated. +fn get_mul_targets( + cids: Vec, + xcats: Vec, + dfw: &DataFrame, +) -> Result>, Box> { + let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone()); + let mut mul_targets = HashMap::new(); + + let found_tickers = dfw + .get_column_names() + .iter() + .map(|name| name.to_string()) + .collect::>(); + + let (agg_on, agg_targ) = get_agg_on_agg_targs(cids.clone(), xcats.clone()); + + for agg_o in agg_on { + let mut targets = Vec::new(); + for agg_t in &agg_targ { + let ticker = match _agg_xcats_for_cid { + true => format!("{}_{}", agg_t, agg_o), + false => format!("{}_{}", agg_o, agg_t), + }; + if found_tickers.contains(&ticker) { + targets.push(ticker); + } + } + if !targets.is_empty() { + mul_targets.insert(agg_o.clone(), targets); + } + } + Ok(mul_targets) +} + +fn form_weights_and_signs_map( + cids: Vec, + xcats: Vec, + weights: Option>, + signs: Option>, +) -> Result>>, Box> { + let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone()); + + let (agg_on, agg_targ) = get_agg_on_agg_targs(cids.clone(), xcats.clone()); + + // if weights are None, create a vector of 1s of the same length as agg_targ + let weights = weights.unwrap_or(vec![1.0; agg_targ.len()]); + let signs = signs.unwrap_or(vec![1.0; agg_targ.len()]); + + // check that the lengths of weights and signs match the length of agg_targ + check_weights_signs_lengths( + weights.clone(), + signs.clone(), + _agg_xcats_for_cid, + agg_targ.len(), + )?; + + let mut weights_map = HashMap::new(); + + for agg_o in agg_on { + let mut agg_t_map = HashMap::new(); + for (i, agg_t) in agg_targ.iter().enumerate() { + let ticker = match _agg_xcats_for_cid { + true => format!("{}_{}", agg_t, agg_o), + false => format!("{}_{}", agg_o, agg_t), + }; + let weight_signs = vec![weights[i], signs[i]]; + agg_t_map.insert(ticker, weight_signs); + } + weights_map.insert(agg_o.clone(), agg_t_map); + } + Ok(weights_map) +} + +fn check_weights_signs_lengths( + weights_vec: Vec, + signs_vec: Vec, + _agg_xcats_for_cid: bool, + agg_targ_len: usize, +) -> Result<(), Box> { + // for vx, vname in ... + let agg_targ = match _agg_xcats_for_cid { + true => "xcats", + false => "cids", + }; + for (vx, vname) in vec![(weights_vec.len(), "weights"), (signs_vec.len(), "signs")] { + if vx != agg_targ_len { + return Err(format!( + "The length of {} ({}) does not match the length of {} ({})", + vname, vx, agg_targ, agg_targ_len + ) + .into()); + } + } + Ok(()) +} + +/// Flags if the xcats are aggregated for a given cid. +/// If true, the xcats are aggregated for each cid, creating a new xcat. +/// If false, the cids are aggregated for each xcat, creating a new cid. +fn agg_xcats_for_cid(cids: Vec, xcats: Vec) -> bool { + // if there is more than 1 xcat, return xcats.len() > 1 + xcats.len() > 1 +} + +/// Weighted linear combinations of cross sections or categories +/// # Arguments +/// * `df` - QDF DataFrame +/// * `xcats` - List of category names or a single category name +/// * `cids` - List of cross section names or None +/// * `weights` - List of weights or a string indicating a weight `xcat` +/// * `normalize_weights` - Normalize weights to sum to 1 before applying +/// * `signs` - List of signs for each category (+1 or -1) +/// * `start` - Start date for the analysis +/// * `end` - End date for the analysis +/// * `blacklist` - Dictionary of blacklisted categories +/// * `complete_xcats` - If True, complete xcats with missing values +/// * `complete_cids` - If True, complete cids with missing values +/// * `new_xcat` - Name of the new xcat +/// * `new_cid` - Name of the new cid +/// +/// # Returns +/// * `DataFrame` - DataFrame with the linear composite +pub fn linear_composite( + df: &DataFrame, + xcats: Vec, + cids: Vec, + weights: Option>, + signs: Option>, + weight_xcats: Option>, + normalize_weights: bool, + start: Option, + end: Option, + blacklist: Option>>, + complete_xcats: bool, + complete_cids: bool, + new_xcat: Option, + new_cid: Option, +) -> Result> { + // Check if the DataFrame is a Quantamental DataFrame + check_quantamental_dataframe(df)?; + let rdf = reduce_dataframe( + df.clone(), + Some(cids.clone()), + Some(xcats.clone()), + Some(vec!["value".to_string()]), + start.clone(), + end.clone(), + false, + ) + .unwrap(); + + let mut dfw = pivot_dataframe_by_ticker(rdf, Some("value".to_string())).unwrap(); + + let mul_targets = get_mul_targets(cids.clone(), xcats.clone(), &dfw)?; + let weights_map = form_weights_and_signs_map(cids.clone(), xcats.clone(), weights, signs)?; + + for (ticker, targets) in mul_targets.iter() { + println!("ticker: {}, targets: {:?}", ticker, targets); + } + for (agg_on, agg_t_map) in weights_map.iter() { + println!("agg_on: {}, agg_t_map: {:?}", agg_on, agg_t_map); + } + + Ok(dfw) +}