msyrs/src/utils/qdf/reduce_df.rs
2024-11-17 23:58:47 +00:00

163 lines
5.3 KiB
Rust

use crate::utils::misc::*;
use crate::utils::qdf::core::*;
use polars::prelude::*;
use std::error::Error;
/// Index columns that every Quantamental DataFrame carries.
/// Any column not listed here is treated as a metric column by `reduce_dataframe`.
const QDF_INDEX_COLUMNS: [&str; 3] = [
    "real_date", // date column that the `start`/`end` bounds are applied to
    "cid",       // cross-sectional identifier
    "xcat",      // extended category
];
/// Filter a Quantamental DataFrame by cross-sections, categories, metrics and dates.
///
/// - `cids`: cross-sectional identifiers to keep; `None` keeps every cid present in `df`.
/// - `xcats`: extended categories to keep; `None` keeps every xcat present in `df`.
/// - `metrics`: metric (non-index) columns to keep; `None` keeps every non-index column.
///   The index columns (`real_date`, `cid`, `xcat`) are always kept. Names that do not
///   exist in `df` are ignored.
/// - `start`: inclusive lower bound on `real_date`, formatted `%Y-%m-%d`.
/// - `end`: inclusive upper bound on `real_date`, formatted `%Y-%m-%d`.
/// - `intersect`: if `true`, keep only the `cids` that are present for every requested
///   `xcat`.
///
/// Rows are kept when their ticker is one of the tickers that `create_interesecting_tickers`
/// builds from the selected cids and xcats. `df` is consumed, and the filtered frame is
/// returned. When the shape of the result differs from the input, a one-line summary is
/// printed to stdout.
///
/// # Errors
/// Returns an error if:
/// - `df` fails `check_quantamental_dataframe`;
/// - `start` or `end` cannot be parsed as `%Y-%m-%d`;
/// - no metric columns are selected (`metrics` is `Some(vec![])`, or `df` has no
///   non-index columns);
/// - any underlying polars operation fails.
pub fn reduce_dataframe(
    df: DataFrame,
    cids: Option<Vec<String>>,
    xcats: Option<Vec<String>>,
    metrics: Option<Vec<String>>,
    start: Option<String>,
    end: Option<String>,
    intersect: bool,
) -> Result<DataFrame, Box<dyn Error>> {
    use std::collections::HashSet;

    check_quantamental_dataframe(&df)?;
    let df_size = df.shape();
    // `df` is owned, so filter it in place instead of cloning it.
    let mut new_df = df;

    // Resolve the metric selection first so that invalid input fails before any work is done.
    let specified_metrics: Vec<String> = match metrics {
        Some(m) => m,
        None => new_df
            .get_column_names()
            .iter()
            .filter(|c| !QDF_INDEX_COLUMNS.contains(&c.as_str()))
            .map(|c| c.to_string())
            .collect(),
    };
    // Report an empty selection to the caller rather than panicking inside a library function.
    if specified_metrics.is_empty() {
        return Err("reduce_dataframe: no metric columns selected".into());
    }

    // Parse both date bounds up front so a malformed date is reported before filtering.
    let start_date = start
        .map(|s| chrono::NaiveDate::parse_from_str(&s, "%Y-%m-%d"))
        .transpose()?;
    let end_date = end
        .map(|s| chrono::NaiveDate::parse_from_str(&s, "%Y-%m-%d"))
        .transpose()?;

    // Resolve the cid/xcat selections, defaulting to everything present in the frame.
    let cids_vec: Vec<String> = match cids {
        Some(c) => c,
        None => get_unique_cids(&new_df)?,
    };
    let xcats_vec: Vec<String> = match xcats {
        Some(x) => x,
        None => get_unique_xcats(&new_df)?,
    };
    let specified_cids: Vec<&str> = cids_vec.iter().map(String::as_str).collect();
    let specified_xcats: Vec<&str> = xcats_vec.iter().map(String::as_str).collect();

    let ticker_col: Column = get_ticker_column_for_quantamental_dataframe(&new_df)?;

    let keep_tickers: Vec<String> = if intersect {
        // Only the intersect path needs the tickers that actually exist in the frame.
        let u_tickers: Vec<String> = _get_unique_strs_from_str_column_object(&ticker_col)?;
        let int_cids = get_intersecting_cids_str_func(&cids_vec, &xcats_vec, &u_tickers);
        create_interesecting_tickers(
            &int_cids.iter().map(AsRef::as_ref).collect::<Vec<&str>>(),
            &specified_xcats,
        )
    } else {
        create_interesecting_tickers(&specified_cids, &specified_xcats)
    };

    // Hash the kept tickers once so that each per-row membership test is O(1) rather than
    // a linear scan. Null tickers never match.
    let keep: HashSet<&str> = keep_tickers.iter().map(String::as_str).collect();
    let mask: Vec<bool> = ticker_col
        .str()?
        .iter()
        .map(|t| matches!(t, Some(t) if keep.contains(t)))
        .collect();
    let mask = BooleanChunked::from_slice("mask".into(), &mask);
    new_df = new_df.filter(&mask)?;

    // Apply the inclusive date bounds in a single lazy pass.
    if start_date.is_some() || end_date.is_some() {
        let mut lf = new_df.lazy();
        if let Some(d) = start_date {
            lf = lf.filter(col("real_date").gt_eq(lit(d)));
        }
        if let Some(d) = end_date {
            lf = lf.filter(col("real_date").lt_eq(lit(d)));
        }
        new_df = lf.collect()?;
    }

    // Drop every non-index column that was not selected as a metric.
    let keep_metrics: HashSet<&str> = specified_metrics.iter().map(String::as_str).collect();
    let cols_to_remove: Vec<String> = new_df
        .get_column_names()
        .iter()
        .filter(|c| {
            let name = c.as_str();
            !QDF_INDEX_COLUMNS.contains(&name) && !keep_metrics.contains(name)
        })
        .map(|c| c.to_string())
        .collect();
    new_df = new_df.drop_many(cols_to_remove);

    // Report when the filtering changed the shape of the frame (rows or columns).
    let new_df_size = new_df.shape();
    if df_size != new_df_size {
        println!(
            "Reduced DataFrame from {} to {} rows",
            df_size.0, new_df_size.0
        );
    }
    Ok(new_df)
}