diff --git a/src/panel/linear_composite.rs b/src/panel/linear_composite.rs index 728b933..8126174 100644 --- a/src/panel/linear_composite.rs +++ b/src/panel/linear_composite.rs @@ -1,6 +1,6 @@ use crate::utils::qdf::check_quantamental_dataframe; -use crate::utils::qdf::pivots::*; -use crate::utils::qdf::reduce_df::*; +use crate::utils::qdf::pivots::{pivot_dataframe_by_ticker, pivot_wide_dataframe_to_qdf}; +use crate::utils::qdf::reduce_df::reduce_dataframe; use polars::prelude::*; use std::collections::HashMap; const TOLERANCE: f64 = 1e-8; @@ -108,14 +108,42 @@ fn _form_agg_nan_mask_series(nan_mask_dfw: &DataFrame) -> Result>, - data_dfw: DataFrame, + agg_weights_map: &HashMap, + dfw: &DataFrame, ) -> Result { let mut weights_dfw = DataFrame::new(vec![])?; for (agg_targ, weight_signs) in agg_weights_map.iter() { - let wgt = weight_signs[0] * weight_signs[1]; - let wgt_series = Series::new(agg_targ.into(), vec![wgt; data_dfw.height()]); + // let wgt = weight_signs[0] * weight_signs[1]; + let wgt_series = match &weight_signs.0 { + WeightValue::F64(val) => { + let wgt = val * weight_signs.1; + Series::new(agg_targ.into(), vec![wgt; dfw.height()]) + } + WeightValue::Str(vstr) => { + // vstr column from data_dfw, else raise wieght specification error + if !dfw.get_column_names().contains(&&PlSmallStr::from(vstr)) { + return Err(PolarsError::ComputeError( + format!( + "The column {} does not exist in the DataFrame. {:?}", + vstr, agg_weights_map + ) + .into(), + )); + } + let vstr_series = dfw.column(vstr)?; + let multiplied_series = vstr_series * weight_signs.1; + let mut multiplied_series = + multiplied_series.as_series().cloned().ok_or_else(|| { + PolarsError::ComputeError( + "Failed to convert multiplied_series to Series".into(), + ) + })?; + multiplied_series.rename(agg_targ.into()); + multiplied_series + } + }; weights_dfw.with_column(wgt_series)?; } Ok(weights_dfw) @@ -143,14 +171,14 @@ fn perform_single_group_agg( dfw: &DataFrame, agg_on: &String, agg_targs: &Vec, - agg_weights_map: &HashMap>, + agg_weights_map: &HashMap, normalize_weights: bool, complete: bool, ) -> Result { let data_dfw = _form_agg_data_dfw(dfw, agg_targs)?; let nan_mask_dfw = _form_agg_nan_mask_dfw(&data_dfw)?; let nan_mask_series = _form_agg_nan_mask_series(&nan_mask_dfw)?; - let weights_dfw = _form_agg_weights_dfw(agg_weights_map, data_dfw.clone())?; + let weights_dfw = _form_agg_weights_dfw(agg_weights_map, dfw)?; let weights_dfw = match normalize_weights { true => normalize_weights_with_nan_mask(weights_dfw, nan_mask_dfw)?, false => weights_dfw, @@ -192,7 +220,7 @@ fn perform_single_group_agg( fn perform_multiplication( dfw: &DataFrame, mult_targets: &HashMap>, - weights_map: &HashMap>>, + weights_map: &HashMap>, complete: bool, normalize_weights: bool, ) -> Result { @@ -200,6 +228,7 @@ fn perform_multiplication( // let mut new_dfw = DataFrame::new(vec![real_date])?; let mut new_dfw = DataFrame::new(vec![])?; assert!(!mult_targets.is_empty(), "agg_targs is empty"); + for (agg_on, agg_targs) in mult_targets.iter() { // perform_single_group_agg let cols_len = new_dfw.get_column_names().len(); @@ -288,76 +317,122 @@ fn get_mul_targets( Ok(mul_targets) } +/// Builds a map of the shape: +/// `HashMap>` +/// where only one of `weights` or `weight_xcats` can be provided. +/// If neither is provided, weights default to 1.0. +/// Each tuple is `(WeightValue, f64) = (weight, sign)`. fn form_weights_and_signs_map( cids: Vec, xcats: Vec, weights: Option>, + weight_xcat: Option, signs: Option>, -) -> Result>>, Box> { - let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone()); - +) -> Result>, Box> { + // For demonstration, we pretend to load or infer these from helpers: + let agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone()); let (agg_on, agg_targ) = get_agg_on_agg_targs(cids.clone(), xcats.clone()); - // if weights are None, create a vector of 1s of the same length as agg_targ - let weights = weights.unwrap_or(vec![1.0 / agg_targ.len() as f64; agg_targ.len()]); - let signs = signs.unwrap_or(vec![1.0; agg_targ.len()]); + // Determine if each weight option has non-empty values. + let weights_provided = weights.as_ref().map_or(false, |v| !v.is_empty()); + let weight_xcats_provided = weight_xcat.as_ref().map_or(false, |v| !v.is_empty()); - // check that the lengths of weights and signs match the length of agg_targ - check_weights_signs_lengths( - weights.clone(), - signs.clone(), - _agg_xcats_for_cid, - agg_targ.len(), - )?; + // Enforce that only one of weights or weight_xcats is specified. + if weights_provided && weight_xcats_provided { + return Err("Only one of `weights` and `weight_xcats` may be specified.".into()); + } - let mut weights_map = HashMap::new(); + // 1) Build the "actual_weights" vector as WeightValue. + let actual_weights: Vec = if weights_provided { + weights.unwrap().into_iter().map(WeightValue::F64).collect() + } else if weight_xcats_provided { + vec![WeightValue::Str(weight_xcat.unwrap()); agg_targ.len()] + } else { + // Default to numeric 1.0 if neither is provided + vec![WeightValue::F64(1.0); agg_targ.len()] + }; + + // 2) Build the "signs" vector; default to 1.0 if not provided + let signs = signs.unwrap_or_else(|| vec![1.0; agg_targ.len()]); + + // 3) Optional: check lengths & zero values (only numeric weights). + check_weights_signs_lengths(&actual_weights, &signs, agg_xcats_for_cid, agg_targ.len())?; + + // 4) Build the final nested HashMap + let mut weights_map: HashMap> = HashMap::new(); for agg_o in agg_on { let mut agg_t_map = HashMap::new(); for (i, agg_t) in agg_targ.iter().enumerate() { - let ticker = match _agg_xcats_for_cid { - true => format!("{}_{}", agg_o, agg_t), - false => format!("{}_{}", agg_t, agg_o), + // Format the ticker + let ticker = if agg_xcats_for_cid { + format!("{}_{}", agg_o, agg_t) + } else { + format!("{}_{}", agg_t, agg_o) }; - let weight_signs = vec![weights[i], signs[i]]; - agg_t_map.insert(ticker, weight_signs); + // Build the tuple (WeightValue, f64) + let weight_sign_tuple = match &actual_weights[i] { + WeightValue::F64(val) => (WeightValue::F64(*val).clone(), signs[i]), + WeightValue::Str(vstr) => { + let new_str = format!("{}_{}", agg_t, vstr); + (WeightValue::Str(new_str), signs[i]) + } + }; + agg_t_map.insert(ticker, weight_sign_tuple); } weights_map.insert(agg_o.clone(), agg_t_map); } + Ok(weights_map) } - +/// Checks that the given slices have the expected length and that: +/// - numeric weights are non-zero, +/// - signs are non-zero. fn check_weights_signs_lengths( - weights_vec: Vec, - signs_vec: Vec, - _agg_xcats_for_cid: bool, + weights_vec: &[WeightValue], + signs_vec: &[f64], + agg_xcats_for_cid: bool, agg_targ_len: usize, ) -> Result<(), Box> { - // for vx, vname in ... - let agg_targ = match _agg_xcats_for_cid { - true => "xcats", - false => "cids", - }; - for (vx, vname) in vec![ - (weights_vec.clone(), "weights"), - (signs_vec.clone(), "signs"), - ] { - for (i, v) in vx.iter().enumerate() { - if *v == 0.0 { - return Err(format!("The {} at index {} is 0.0", vname, i).into()); + // For diagnostics, decide what to call the dimension + let agg_targ = if agg_xcats_for_cid { "xcats" } else { "cids" }; + + // 1) Check numeric weights for zeroes. + for (i, weight) in weights_vec.iter().enumerate() { + if let WeightValue::F64(val) = weight { + if *val == 0.0 { + return Err(format!("The weight at index {} is 0.0", i).into()); } } - if vx.len() != agg_targ_len { - return Err(format!( - "The length of {} ({}) does not match the length of {} ({})", - vname, - vx.len(), - agg_targ, - agg_targ_len - ) - .into()); + } + // 2) Ensure the weights vector is the expected length. + if weights_vec.len() != agg_targ_len { + return Err(format!( + "The length of weights ({}) does not match the length of {} ({})", + weights_vec.len(), + agg_targ, + agg_targ_len + ) + .into()); + } + + // 3) Check signs for zero. + for (i, sign) in signs_vec.iter().enumerate() { + if *sign == 0.0 { + return Err(format!("The sign at index {} is 0.0", i).into()); } } + // 4) Ensure the signs vector is the expected length. + if signs_vec.len() != agg_targ_len { + return Err(format!( + "The length of signs ({}) does not match the length of {} ({})", + signs_vec.len(), + agg_targ, + agg_targ_len + ) + .into()); + } + Ok(()) } fn rename_result_dfw_cols( @@ -393,6 +468,36 @@ fn agg_xcats_for_cid(cids: Vec, xcats: Vec) -> bool { xcats.len() > 1 } +/// Represents a weight value that can be a string, (float, or integer). +#[derive(Debug, Clone, PartialEq)] +pub enum WeightValue { + Str(String), + F64(f64), +} +impl From for WeightValue { + fn from(s: String) -> Self { + WeightValue::Str(s) + } +} + +impl<'a> From<&'a str> for WeightValue { + fn from(s: &'a str) -> Self { + WeightValue::Str(s.to_string()) + } +} + +impl From for WeightValue { + fn from(f: f64) -> Self { + WeightValue::F64(f) + } +} + +impl From for WeightValue { + fn from(i: i32) -> Self { + WeightValue::F64(i as f64) + } +} + /// Weighted linear combinations of cross sections or categories /// # Arguments /// * `df` - QDF DataFrame @@ -417,7 +522,7 @@ pub fn linear_composite( cids: Vec, weights: Option>, signs: Option>, - weight_xcats: Option>, + weight_xcat: Option, normalize_weights: bool, start: Option, end: Option, @@ -429,10 +534,28 @@ pub fn linear_composite( ) -> Result> { // Check if the DataFrame is a Quantamental DataFrame check_quantamental_dataframe(df)?; + + if agg_xcats_for_cid(cids.clone(), xcats.clone()) { + if weight_xcat.is_some() { + return Err( + format!( + "Using xcats as weights is not supported when aggregating cids for a single xcat. {:?} {:?}", + cids, xcats + ) + .into(), + ); + } + } + + let mut rxcats = xcats.clone(); + if weight_xcat.is_some() { + rxcats.extend(vec![weight_xcat.clone().unwrap()]); + } + let rdf = reduce_dataframe( df.clone(), Some(cids.clone()), - Some(xcats.clone()), + Some(rxcats.clone()), Some(vec!["value".to_string()]), start.clone(), end.clone(), @@ -443,10 +566,11 @@ pub fn linear_composite( let new_xcat = new_xcat.unwrap_or_else(|| "COMPOSITE".to_string()); let new_cid = new_cid.unwrap_or_else(|| "GLB".to_string()); - let dfw = pivot_dataframe_by_ticker(rdf.clone(), Some("value".to_string())).unwrap(); + let dfw = pivot_dataframe_by_ticker(rdf, Some("value".to_string())).unwrap(); let mul_targets = get_mul_targets(cids.clone(), xcats.clone())?; - let weights_map = form_weights_and_signs_map(cids.clone(), xcats.clone(), weights, signs)?; + let weights_map = + form_weights_and_signs_map(cids.clone(), xcats.clone(), weights, weight_xcat, signs)?; for (ticker, targets) in mul_targets.iter() { println!("ticker: {}, targets: {:?}", ticker, targets);