mirror of
https://github.com/Magnus167/msyrs.git
synced 2025-08-20 13:10:00 +00:00
adding linear_composite initial cut
This commit is contained in:
parent
acbe621704
commit
7c2ed3b818
@ -0,0 +1,272 @@
|
||||
use crate::utils::misc::*;
|
||||
use crate::utils::qdf::check_quantamental_dataframe;
|
||||
use crate::utils::qdf::pivots::*;
|
||||
use crate::utils::qdf::reduce_df::*;
|
||||
use chrono::NaiveDate;
|
||||
use ndarray::Data;
|
||||
use ndarray::{s, Array, Array1, Zip};
|
||||
use polars::prelude::*;
|
||||
use polars::series::Series; // Series struct
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn perform_single_group_agg(
|
||||
dfw: &DataFrame,
|
||||
new_dfw: &DataFrame,
|
||||
agg_on: &String,
|
||||
agg_targs: &Vec<String>,
|
||||
agg_weights_map: &HashMap<String, Vec<f64>>,
|
||||
normalize_weights: bool,
|
||||
complete: bool,
|
||||
) -> Result<DataFrame, PolarsError> {
|
||||
// Replace this with the actual implementation
|
||||
// get all agg_targs as columns
|
||||
|
||||
let mut weights_dfw = DataFrame::new(vec![])?; // Placeholder for weights DataFrame
|
||||
for (agg_targ, weight_signs) in agg_weights_map.iter() {
|
||||
let wgt = weight_signs[0] * weight_signs[1];
|
||||
let wgt_series = Series::new(agg_targ.into(), vec![wgt]);
|
||||
weights_dfw.with_column(wgt_series)?;
|
||||
}
|
||||
|
||||
let mut data_dfw = DataFrame::new(vec![])?; // Placeholder for target DataFrame
|
||||
for agg_targ in agg_targs {
|
||||
if !dfw.get_column_names().contains(&&PlSmallStr::from_string(agg_targ.to_string())) {
|
||||
continue;
|
||||
}
|
||||
let agg_targ_series = dfw.column(agg_targ)?.clone();
|
||||
data_dfw.with_column(agg_targ_series)?;
|
||||
}
|
||||
|
||||
|
||||
// nan_mask = [iter over data_dfw.columns() applying is_nan()] OR [iter over data_dfw.rows() applying is_nan()]
|
||||
let mut nan_mask = DataFrame::new(vec![])?; // Placeholder for NaN mask DataFrame
|
||||
for col in data_dfw.get_column_names() {
|
||||
let col_series = data_dfw.column(col)?;
|
||||
let nan_mask_series = col_series.is_nan()?
|
||||
.cast(&DataType::Boolean)?
|
||||
.into_series();
|
||||
nan_mask.with_column(nan_mask_series)?;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Ok(new_dfw.clone())
|
||||
}
|
||||
|
||||
fn perform_multiplication(
|
||||
dfw: &DataFrame,
|
||||
mult_targets: &HashMap<String, Vec<String>>,
|
||||
weights_map: &HashMap<String, HashMap<String, Vec<f64>>>,
|
||||
complete: bool,
|
||||
normalize_weights: bool,
|
||||
) -> Result<DataFrame, PolarsError> {
|
||||
let real_date = dfw.column("real_date")?.clone();
|
||||
let mut new_dfw = DataFrame::new(vec![real_date])?;
|
||||
|
||||
for (agg_on, agg_targs) in mult_targets.iter() {
|
||||
// perform_single_group_agg
|
||||
perform_single_group_agg(
|
||||
dfw,
|
||||
&new_dfw,
|
||||
agg_on,
|
||||
agg_targs,
|
||||
&weights_map[agg_on],
|
||||
normalize_weights,
|
||||
complete,
|
||||
)?;
|
||||
}
|
||||
|
||||
// Placeholder logic to return a valid DataFrame for now
|
||||
// Replace this with the actual implementation
|
||||
Ok(new_dfw)
|
||||
}
|
||||
|
||||
fn get_agg_on_agg_targs(cids: Vec<String>, xcats: Vec<String>) -> (Vec<String>, Vec<String>) {
|
||||
let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone());
|
||||
let (agg_on, agg_targ) = if _agg_xcats_for_cid {
|
||||
(cids.clone(), xcats.clone())
|
||||
} else {
|
||||
(xcats.clone(), cids.clone())
|
||||
};
|
||||
// assert that if agg_xcats_for_cid is true, agg_on = cids
|
||||
match _agg_xcats_for_cid {
|
||||
true => {
|
||||
assert_eq!(agg_on, cids);
|
||||
assert_eq!(agg_targ, xcats);
|
||||
}
|
||||
false => {
|
||||
assert_eq!(agg_on, xcats);
|
||||
assert_eq!(agg_targ, cids);
|
||||
}
|
||||
}
|
||||
(agg_on, agg_targ)
|
||||
}
|
||||
|
||||
/// Get the mapping of aggregation targets for the implied mode of aggregation.
|
||||
/// # Returns
|
||||
/// * `HashMap<String, Vec<String>>` - A mapping of cid/xcat to the list of tickers to be aggregated.
|
||||
fn get_mul_targets(
|
||||
cids: Vec<String>,
|
||||
xcats: Vec<String>,
|
||||
dfw: &DataFrame,
|
||||
) -> Result<HashMap<String, Vec<String>>, Box<dyn std::error::Error>> {
|
||||
let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone());
|
||||
let mut mul_targets = HashMap::new();
|
||||
|
||||
let found_tickers = dfw
|
||||
.get_column_names()
|
||||
.iter()
|
||||
.map(|name| name.to_string())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
let (agg_on, agg_targ) = get_agg_on_agg_targs(cids.clone(), xcats.clone());
|
||||
|
||||
for agg_o in agg_on {
|
||||
let mut targets = Vec::new();
|
||||
for agg_t in &agg_targ {
|
||||
let ticker = match _agg_xcats_for_cid {
|
||||
true => format!("{}_{}", agg_t, agg_o),
|
||||
false => format!("{}_{}", agg_o, agg_t),
|
||||
};
|
||||
if found_tickers.contains(&ticker) {
|
||||
targets.push(ticker);
|
||||
}
|
||||
}
|
||||
if !targets.is_empty() {
|
||||
mul_targets.insert(agg_o.clone(), targets);
|
||||
}
|
||||
}
|
||||
Ok(mul_targets)
|
||||
}
|
||||
|
||||
fn form_weights_and_signs_map(
|
||||
cids: Vec<String>,
|
||||
xcats: Vec<String>,
|
||||
weights: Option<Vec<f64>>,
|
||||
signs: Option<Vec<f64>>,
|
||||
) -> Result<HashMap<String, HashMap<String, Vec<f64>>>, Box<dyn std::error::Error>> {
|
||||
let _agg_xcats_for_cid = agg_xcats_for_cid(cids.clone(), xcats.clone());
|
||||
|
||||
let (agg_on, agg_targ) = get_agg_on_agg_targs(cids.clone(), xcats.clone());
|
||||
|
||||
// if weights are None, create a vector of 1s of the same length as agg_targ
|
||||
let weights = weights.unwrap_or(vec![1.0; agg_targ.len()]);
|
||||
let signs = signs.unwrap_or(vec![1.0; agg_targ.len()]);
|
||||
|
||||
// check that the lengths of weights and signs match the length of agg_targ
|
||||
check_weights_signs_lengths(
|
||||
weights.clone(),
|
||||
signs.clone(),
|
||||
_agg_xcats_for_cid,
|
||||
agg_targ.len(),
|
||||
)?;
|
||||
|
||||
let mut weights_map = HashMap::new();
|
||||
|
||||
for agg_o in agg_on {
|
||||
let mut agg_t_map = HashMap::new();
|
||||
for (i, agg_t) in agg_targ.iter().enumerate() {
|
||||
let ticker = match _agg_xcats_for_cid {
|
||||
true => format!("{}_{}", agg_t, agg_o),
|
||||
false => format!("{}_{}", agg_o, agg_t),
|
||||
};
|
||||
let weight_signs = vec![weights[i], signs[i]];
|
||||
agg_t_map.insert(ticker, weight_signs);
|
||||
}
|
||||
weights_map.insert(agg_o.clone(), agg_t_map);
|
||||
}
|
||||
Ok(weights_map)
|
||||
}
|
||||
|
||||
/// Validate that `weights_vec` and `signs_vec` each hold exactly
/// `agg_targ_len` entries.
///
/// # Arguments
/// * `weights_vec` - Weights to check.
/// * `signs_vec` - Signs to check.
/// * `_agg_xcats_for_cid` - Selects the label used in the error message:
///   `"xcats"` when true, `"cids"` when false.
/// * `agg_targ_len` - Required length.
///
/// # Errors
/// Returns a message naming the first vector (weights, then signs) whose
/// length differs from `agg_targ_len`.
fn check_weights_signs_lengths<T>(
    weights_vec: Vec<T>,
    signs_vec: Vec<T>,
    _agg_xcats_for_cid: bool,
    agg_targ_len: usize,
) -> Result<(), Box<dyn std::error::Error>> {
    let agg_targ = if _agg_xcats_for_cid { "xcats" } else { "cids" };

    // A fixed-size array is enough here; no heap allocation is needed.
    for (vx, vname) in [(weights_vec.len(), "weights"), (signs_vec.len(), "signs")] {
        if vx != agg_targ_len {
            return Err(format!(
                "The length of {} ({}) does not match the length of {} ({})",
                vname, vx, agg_targ, agg_targ_len
            )
            .into());
        }
    }

    Ok(())
}
|
||||
|
||||
/// Flags if the xcats are aggregated for a given cid.
/// If true, the xcats are aggregated for each cid, creating a new xcat.
/// If false, the cids are aggregated for each xcat, creating a new cid.
///
/// The decision depends only on `xcats`: it is true when more than one xcat
/// is supplied.
// `cids` is not consulted yet; it is kept so existing call sites stay valid.
#[allow(unused_variables)]
fn agg_xcats_for_cid(cids: Vec<String>, xcats: Vec<String>) -> bool {
    xcats.len() > 1
}
|
||||
|
||||
/// Weighted linear combinations of cross sections or categories
|
||||
/// # Arguments
|
||||
/// * `df` - QDF DataFrame
|
||||
/// * `xcats` - List of category names or a single category name
|
||||
/// * `cids` - List of cross section names or None
|
||||
/// * `weights` - List of weights or a string indicating a weight `xcat`
|
||||
/// * `normalize_weights` - Normalize weights to sum to 1 before applying
|
||||
/// * `signs` - List of signs for each category (+1 or -1)
|
||||
/// * `start` - Start date for the analysis
|
||||
/// * `end` - End date for the analysis
|
||||
/// * `blacklist` - Dictionary of blacklisted categories
|
||||
/// * `complete_xcats` - If True, complete xcats with missing values
|
||||
/// * `complete_cids` - If True, complete cids with missing values
|
||||
/// * `new_xcat` - Name of the new xcat
|
||||
/// * `new_cid` - Name of the new cid
|
||||
///
|
||||
/// # Returns
|
||||
/// * `DataFrame` - DataFrame with the linear composite
|
||||
pub fn linear_composite(
|
||||
df: &DataFrame,
|
||||
xcats: Vec<String>,
|
||||
cids: Vec<String>,
|
||||
weights: Option<Vec<f64>>,
|
||||
signs: Option<Vec<f64>>,
|
||||
weight_xcats: Option<Vec<String>>,
|
||||
normalize_weights: bool,
|
||||
start: Option<String>,
|
||||
end: Option<String>,
|
||||
blacklist: Option<HashMap<String, Vec<String>>>,
|
||||
complete_xcats: bool,
|
||||
complete_cids: bool,
|
||||
new_xcat: Option<String>,
|
||||
new_cid: Option<String>,
|
||||
) -> Result<DataFrame, Box<dyn std::error::Error>> {
|
||||
// Check if the DataFrame is a Quantamental DataFrame
|
||||
check_quantamental_dataframe(df)?;
|
||||
let rdf = reduce_dataframe(
|
||||
df.clone(),
|
||||
Some(cids.clone()),
|
||||
Some(xcats.clone()),
|
||||
Some(vec!["value".to_string()]),
|
||||
start.clone(),
|
||||
end.clone(),
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut dfw = pivot_dataframe_by_ticker(rdf, Some("value".to_string())).unwrap();
|
||||
|
||||
let mul_targets = get_mul_targets(cids.clone(), xcats.clone(), &dfw)?;
|
||||
let weights_map = form_weights_and_signs_map(cids.clone(), xcats.clone(), weights, signs)?;
|
||||
|
||||
for (ticker, targets) in mul_targets.iter() {
|
||||
println!("ticker: {}, targets: {:?}", ticker, targets);
|
||||
}
|
||||
for (agg_on, agg_t_map) in weights_map.iter() {
|
||||
println!("agg_on: {}, agg_t_map: {:?}", agg_on, agg_t_map);
|
||||
}
|
||||
|
||||
Ok(dfw)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user