adding linear_composite initial cut

This commit is contained in:
Palash Tyagi 2025-04-07 00:22:34 +01:00
parent acbe621704
commit 7c2ed3b818

View File

@ -0,0 +1,272 @@
use crate::utils::misc::*;
use crate::utils::qdf::check_quantamental_dataframe;
use crate::utils::qdf::pivots::*;
use crate::utils::qdf::reduce_df::*;
use chrono::NaiveDate;
use ndarray::Data;
use ndarray::{s, Array, Array1, Zip};
use polars::prelude::*;
use polars::series::Series; // Series struct
use std::collections::HashMap;
/// Work-in-progress aggregation step for a single `agg_on` group.
///
/// The current body only builds three intermediate frames (a one-row weights
/// frame, a frame holding the target columns found in `dfw`, and a per-column
/// NaN mask) and then discards them: the value returned is a clone of
/// `new_dfw`, unchanged.
///
/// # Arguments
/// * `dfw` - Frame from which the columns named in `agg_targs` are read.
/// * `new_dfw` - Frame that is cloned and returned as the result.
/// * `agg_on` - Not used by the current body.
/// * `agg_targs` - Column names to pull from `dfw`; names absent from `dfw` are skipped.
/// * `agg_weights_map` - Maps a column name to a vector whose first two elements are
///   multiplied together (presumably `[weight, sign]` - confirm against the caller).
/// * `normalize_weights` - Not used by the current body.
/// * `complete` - Not used by the current body.
///
/// # Errors
/// Propagates any `PolarsError` raised while building the intermediate frames.
///
/// # Panics
/// Panics if a vector in `agg_weights_map` holds fewer than two elements.
fn perform_single_group_agg(
dfw: &DataFrame,
new_dfw: &DataFrame,
agg_on: &String,
agg_targs: &Vec<String>,
agg_weights_map: &HashMap<String, Vec<f64>>,
normalize_weights: bool,
complete: bool,
) -> Result<DataFrame, PolarsError> {
// Replace this with the actual implementation
// get all agg_targs as columns
// One single-value column per entry of the weights map.
let mut weights_dfw = DataFrame::new(vec![])?; // Placeholder for weights DataFrame
for (agg_targ, weight_signs) in agg_weights_map.iter() {
// Product of the first two entries; indexing panics on shorter vectors.
let wgt = weight_signs[0] * weight_signs[1];
let wgt_series = Series::new(agg_targ.into(), vec![wgt]);
weights_dfw.with_column(wgt_series)?;
}
// Copy over only those target columns that actually exist in `dfw`.
let mut data_dfw = DataFrame::new(vec![])?; // Placeholder for target DataFrame
for agg_targ in agg_targs {
if !dfw.get_column_names().contains(&&PlSmallStr::from_string(agg_targ.to_string())) {
continue;
}
let agg_targ_series = dfw.column(agg_targ)?.clone();
data_dfw.with_column(agg_targ_series)?;
}
// nan_mask = [iter over data_dfw.columns() applying is_nan()] OR [iter over data_dfw.rows() applying is_nan()]
// Column-wise variant: one boolean mask column per column of `data_dfw`.
let mut nan_mask = DataFrame::new(vec![])?; // Placeholder for NaN mask DataFrame
for col in data_dfw.get_column_names() {
let col_series = data_dfw.column(col)?;
let nan_mask_series = col_series.is_nan()?
.cast(&DataType::Boolean)?
.into_series();
nan_mask.with_column(nan_mask_series)?;
}
// NOTE(review): `weights_dfw`, `data_dfw` and `nan_mask` are never used after this point.
Ok(new_dfw.clone())
}
fn perform_multiplication(
dfw: &DataFrame,
mult_targets: &HashMap<String, Vec<String>>,
weights_map: &HashMap<String, HashMap<String, Vec<f64>>>,
complete: bool,
normalize_weights: bool,
) -> Result<DataFrame, PolarsError> {
let real_date = dfw.column("real_date")?.clone();
let mut new_dfw = DataFrame::new(vec![real_date])?;
for (agg_on, agg_targs) in mult_targets.iter() {
// perform_single_group_agg
perform_single_group_agg(
dfw,
&new_dfw,
agg_on,
agg_targs,
&weights_map[agg_on],
normalize_weights,
complete,
)?;
}
// Placeholder logic to return a valid DataFrame for now
// Replace this with the actual implementation
Ok(new_dfw)
}
/// Splits `cids` and `xcats` into the dimension that is aggregated *on* and
/// the dimension whose members are aggregated.
///
/// # Returns
/// `(agg_on, agg_targ)`, which is `(cids, xcats)` when `agg_xcats_for_cid`
/// returns `true` for the inputs and `(xcats, cids)` otherwise.
fn get_agg_on_agg_targs(cids: Vec<String>, xcats: Vec<String>) -> (Vec<String>, Vec<String>) {
    // The mode check takes its arguments by value, so it receives copies;
    // the original vectors are then moved straight into the result.
    if agg_xcats_for_cid(cids.clone(), xcats.clone()) {
        (cids, xcats)
    } else {
        (xcats, cids)
    }
}
/// Get the mapping of aggregation targets for the implied mode of aggregation.
///
/// For every element of the "aggregate on" dimension, the candidate ticker
/// names are formed from each element of the "aggregate target" dimension and
/// kept only if a column of that name exists in `dfw`. Groups for which no
/// column is found are left out of the map.
///
/// NOTE(review): in both modes the ticker is formatted as `<xcat>_<cid>`.
/// Confirm that this matches the column naming produced by the pivot step;
/// if columns are named `<cid>_<xcat>`, nothing will ever match.
///
/// # Arguments
/// * `cids` - Cross section names.
/// * `xcats` - Category names.
/// * `dfw` - Frame whose column names are matched against the candidate tickers.
///
/// # Returns
/// * `HashMap<String, Vec<String>>` - A mapping of cid/xcat to the list of tickers to be aggregated.
///
/// # Errors
/// The current body never produces an `Err`.
fn get_mul_targets(
    cids: Vec<String>,
    xcats: Vec<String>,
    dfw: &DataFrame,
) -> Result<HashMap<String, Vec<String>>, Box<dyn std::error::Error>> {
    let xcats_mode = agg_xcats_for_cid(cids.clone(), xcats.clone());
    // Hash the column names once so each candidate lookup is constant time
    // instead of a linear scan per candidate.
    let found_tickers = dfw
        .get_column_names()
        .iter()
        .map(|name| name.to_string())
        .collect::<std::collections::HashSet<String>>();
    let (agg_on, agg_targ) = get_agg_on_agg_targs(cids, xcats);
    let mut mul_targets = HashMap::new();
    for agg_o in agg_on {
        let targets: Vec<String> = agg_targ
            .iter()
            .map(|agg_t| {
                if xcats_mode {
                    format!("{}_{}", agg_t, agg_o)
                } else {
                    format!("{}_{}", agg_o, agg_t)
                }
            })
            .filter(|ticker| found_tickers.contains(ticker))
            .collect();
        if !targets.is_empty() {
            mul_targets.insert(agg_o, targets);
        }
    }
    Ok(mul_targets)
}
/// Builds, for every aggregation group, a map from ticker name to its
/// two-element `[weight, sign]` vector.
///
/// # Arguments
/// * `cids` - Cross section names.
/// * `xcats` - Category names.
/// * `weights` - One weight per aggregation target; `None` means 1.0 for each.
/// * `signs` - One sign per aggregation target; `None` means 1.0 for each.
///
/// # Errors
/// Returns an error when `weights` or `signs` does not have one entry per
/// aggregation target.
fn form_weights_and_signs_map(
    cids: Vec<String>,
    xcats: Vec<String>,
    weights: Option<Vec<f64>>,
    signs: Option<Vec<f64>>,
) -> Result<HashMap<String, HashMap<String, Vec<f64>>>, Box<dyn std::error::Error>> {
    let xcats_mode = agg_xcats_for_cid(cids.clone(), xcats.clone());
    let (groups, members) = get_agg_on_agg_targs(cids, xcats);
    let member_count = members.len();

    // Missing weights or signs fall back to a vector of ones.
    let weights = weights.unwrap_or_else(|| vec![1.0; member_count]);
    let signs = signs.unwrap_or_else(|| vec![1.0; member_count]);

    // Both vectors must line up one-to-one with the aggregation targets.
    check_weights_signs_lengths(weights.clone(), signs.clone(), xcats_mode, member_count)?;

    let weights_map = groups
        .into_iter()
        .map(|group| {
            let per_ticker = members
                .iter()
                .zip(weights.iter().zip(signs.iter()))
                .map(|(member, (&weight, &sign))| {
                    let ticker = if xcats_mode {
                        format!("{}_{}", member, group)
                    } else {
                        format!("{}_{}", group, member)
                    };
                    (ticker, vec![weight, sign])
                })
                .collect::<HashMap<String, Vec<f64>>>();
            (group, per_ticker)
        })
        .collect::<HashMap<String, HashMap<String, Vec<f64>>>>();

    Ok(weights_map)
}
/// Verifies that the weights and signs vectors each hold exactly
/// `agg_targ_len` entries.
///
/// `_agg_xcats_for_cid` only selects the dimension name used in the error
/// message: `"xcats"` when `true`, `"cids"` when `false`.
///
/// # Errors
/// Returns a message naming the first offending vector (weights are checked
/// before signs), its length, and the expected length.
fn check_weights_signs_lengths<T>(
    weights_vec: Vec<T>,
    signs_vec: Vec<T>,
    _agg_xcats_for_cid: bool,
    agg_targ_len: usize,
) -> Result<(), Box<dyn std::error::Error>> {
    let target_label = if _agg_xcats_for_cid { "xcats" } else { "cids" };
    let observed = [("weights", weights_vec.len()), ("signs", signs_vec.len())];
    match observed.iter().find(|entry| entry.1 != agg_targ_len) {
        Some(&(vec_label, vec_len)) => Err(format!(
            "The length of {} ({}) does not match the length of {} ({})",
            vec_label, vec_len, target_label, agg_targ_len
        )
        .into()),
        None => Ok(()),
    }
}
/// Reports which aggregation mode the inputs imply.
///
/// Returns `true` when two or more xcats are given: the xcats are then
/// aggregated for each cid, creating a new xcat. Returns `false` otherwise:
/// the cids are then aggregated for each xcat, creating a new cid.
/// `cids` is not consulted.
fn agg_xcats_for_cid(cids: Vec<String>, xcats: Vec<String>) -> bool {
    // Only the number of xcats decides the mode.
    drop(cids);
    xcats.len() >= 2
}
/// Weighted linear combinations of cross sections or categories.
///
/// NOTE: this is an initial cut. The current body validates `df`, reduces it
/// to the requested cids/xcats and date range, pivots it to one column per
/// ticker, builds the aggregation-target and weight maps, prints both maps
/// to stdout, and returns the pivoted frame. No composite is computed yet.
///
/// # Arguments
/// * `df` - QDF DataFrame
/// * `xcats` - List of category names
/// * `cids` - List of cross section names
/// * `weights` - Optional list of weights, one per aggregated element
///   (per xcat when more than one xcat is given, otherwise per cid);
///   defaults to 1.0 for each
/// * `signs` - Optional list of signs, same length rule as `weights`;
///   defaults to 1.0 for each
/// * `weight_xcats` - Not used by the current implementation
/// * `normalize_weights` - Not used by the current implementation
/// * `start` - Start date passed to the reduction step
/// * `end` - End date passed to the reduction step
/// * `blacklist` - Not used by the current implementation
/// * `complete_xcats` - Not used by the current implementation
/// * `complete_cids` - Not used by the current implementation
/// * `new_xcat` - Not used by the current implementation
/// * `new_cid` - Not used by the current implementation
///
/// # Returns
/// * `DataFrame` - The reduced input pivoted by ticker
///
/// # Errors
/// Returns an error when `check_quantamental_dataframe` rejects `df`, when
/// reducing or pivoting the frame fails, or when `weights`/`signs` do not
/// have the expected length.
pub fn linear_composite(
    df: &DataFrame,
    xcats: Vec<String>,
    cids: Vec<String>,
    weights: Option<Vec<f64>>,
    signs: Option<Vec<f64>>,
    weight_xcats: Option<Vec<String>>,
    normalize_weights: bool,
    start: Option<String>,
    end: Option<String>,
    blacklist: Option<HashMap<String, Vec<String>>>,
    complete_xcats: bool,
    complete_cids: bool,
    new_xcat: Option<String>,
    new_cid: Option<String>,
) -> Result<DataFrame, Box<dyn std::error::Error>> {
    // Check if the DataFrame is a Quantamental DataFrame
    check_quantamental_dataframe(df)?;

    // Failures in the reduce/pivot steps are propagated to the caller
    // instead of aborting the process.
    let rdf = reduce_dataframe(
        df.clone(),
        Some(cids.clone()),
        Some(xcats.clone()),
        Some(vec!["value".to_string()]),
        start,
        end,
        false,
    )?;
    let dfw = pivot_dataframe_by_ticker(rdf, Some("value".to_string()))?;

    let mul_targets = get_mul_targets(cids.clone(), xcats.clone(), &dfw)?;
    let weights_map = form_weights_and_signs_map(cids, xcats, weights, signs)?;

    // Diagnostic output of the maps built above.
    for (ticker, targets) in mul_targets.iter() {
        println!("ticker: {}, targets: {:?}", ticker, targets);
    }
    for (agg_on, agg_t_map) in weights_map.iter() {
        println!("agg_on: {}, agg_t_map: {:?}", agg_on, agg_t_map);
    }

    Ok(dfw)
}