mirror of
https://github.com/Magnus167/msyrs.git
synced 2025-08-20 07:20:01 +00:00
164 lines
5.3 KiB
Rust
164 lines
5.3 KiB
Rust
use crate::utils::misc::*;
|
|
use crate::utils::qdf::core::*;
|
|
use polars::prelude::*;
|
|
use std::error::Error;
|
|
|
|
/// Index (identifier) columns of a Quantamental DataFrame.
///
/// Every column not listed here is treated as a metric column by `reduce_dataframe`.
const QDF_INDEX_COLUMNS: [&str; 3] = [
    "real_date", // observation date
    "cid",       // cross-sectional identifier
    "xcat",      // extended category
];
|
|
|
|
/// Filter a dataframe based on the given parameters.
|
|
/// # Arguments:
|
|
/// - `cids`: Filter by cross-sectional identifiers
|
|
/// - `xcats`: Filter by extended categories
|
|
/// - `metrics`: Filter by metrics
|
|
/// - `start`: Filter by start date
|
|
/// - `end`: Filter by end date
|
|
/// - `intersect`: If true, intersect only return `cids` that are present for all `xcats`.
|
|
/// Returns a new DataFrame with the filtered data, without modifying the original DataFrame.
|
|
/// If no filters are provided, the original DataFrame is returned.
|
|
pub fn reduce_dataframe(
|
|
df: DataFrame,
|
|
cids: Option<Vec<String>>,
|
|
xcats: Option<Vec<String>>,
|
|
metrics: Option<Vec<String>>,
|
|
start: Option<String>,
|
|
end: Option<String>,
|
|
intersect: bool,
|
|
) -> Result<DataFrame, Box<dyn Error>> {
|
|
check_quantamental_dataframe(&df)?;
|
|
// df_size
|
|
let df_size = df.shape();
|
|
let mut new_df = df.clone();
|
|
|
|
let ticker_col = get_ticker_column_for_quantamental_dataframe(&new_df)?;
|
|
|
|
// if cids is not provided, get all unique cids
|
|
let u_cids = get_unique_cids(&new_df)?;
|
|
let u_xcats = get_unique_xcats(&new_df)?;
|
|
let u_tickers = _get_unique_strs_from_str_column_object(&ticker_col)?;
|
|
|
|
let cids_vec = cids.unwrap_or_else(|| u_cids.clone());
|
|
let specified_cids: Vec<&str> = cids_vec.iter().map(AsRef::as_ref).collect();
|
|
let xcats_vec = xcats.unwrap_or_else(|| u_xcats.clone());
|
|
let specified_xcats: Vec<&str> = xcats_vec.iter().map(AsRef::as_ref).collect();
|
|
|
|
let non_idx_cols: Vec<String> = new_df
|
|
.get_column_names()
|
|
.iter()
|
|
.filter(|&col| !QDF_INDEX_COLUMNS.contains(&col.as_str()))
|
|
.map(|s| s.to_string())
|
|
.collect();
|
|
|
|
let specified_metrics: Vec<String> =
|
|
metrics.unwrap_or_else(|| non_idx_cols.iter().map(|s| s.to_string()).collect());
|
|
|
|
let specified_tickers: Vec<String> = create_intersecting_tickers(
|
|
&specified_cids
|
|
.iter()
|
|
.map(AsRef::as_ref)
|
|
.collect::<Vec<&str>>(),
|
|
&specified_xcats
|
|
.iter()
|
|
.map(AsRef::as_ref)
|
|
.collect::<Vec<&str>>(),
|
|
);
|
|
|
|
let keep_tickers: Vec<String> = match intersect {
|
|
// true => get_intersecting_cids_str_func(&specified_cids, &specified_xcats, &u_tickers),
|
|
true => {
|
|
let int_cids = get_intersecting_cids_str_func(
|
|
&specified_cids
|
|
.iter()
|
|
.map(|&s| s.to_string())
|
|
.collect::<Vec<String>>(),
|
|
&specified_xcats
|
|
.iter()
|
|
.map(|&s| s.to_string())
|
|
.collect::<Vec<String>>(),
|
|
&u_tickers,
|
|
);
|
|
create_intersecting_tickers(
|
|
&int_cids.iter().map(AsRef::as_ref).collect::<Vec<&str>>(),
|
|
&specified_xcats
|
|
.iter()
|
|
.map(AsRef::as_ref)
|
|
.collect::<Vec<&str>>(),
|
|
)
|
|
}
|
|
false => specified_tickers.clone(),
|
|
};
|
|
|
|
let kticks: Vec<&str> = keep_tickers
|
|
.iter()
|
|
.map(AsRef::as_ref)
|
|
.collect::<Vec<&str>>();
|
|
|
|
// Create a boolean mask to filter rows based on the tickers
|
|
let mut mask = vec![false; ticker_col.len()];
|
|
for (i, ticker) in ticker_col.str()?.iter().enumerate() {
|
|
if let Some(t) = ticker {
|
|
if kticks.contains(&t) {
|
|
mask[i] = true;
|
|
}
|
|
}
|
|
}
|
|
let mask = BooleanChunked::from_slice("mask".into(), &mask);
|
|
new_df = new_df.filter(&mask)?;
|
|
|
|
// Apply date filtering if `start` or `end` is provided
|
|
|
|
if let Some(start) = start {
|
|
let start_date = chrono::NaiveDate::parse_from_str(&start, "%Y-%m-%d")?;
|
|
new_df = new_df
|
|
.lazy()
|
|
.filter(
|
|
col("real_date")
|
|
.gt_eq(lit(start_date))
|
|
.alias("real_date")
|
|
.into(),
|
|
)
|
|
.collect()?;
|
|
}
|
|
|
|
if let Some(end) = end {
|
|
let end_date = chrono::NaiveDate::parse_from_str(&end, "%Y-%m-%d")?;
|
|
new_df = new_df
|
|
.lazy()
|
|
.filter(
|
|
col("real_date")
|
|
.lt_eq(lit(end_date))
|
|
.alias("real_date")
|
|
.into(),
|
|
)
|
|
.collect()?;
|
|
}
|
|
|
|
// Filter based on metrics if provided
|
|
assert!(specified_metrics.len() > 0);
|
|
|
|
// remove columns that are not in the specified metrics
|
|
let mut cols_to_remove = Vec::new();
|
|
for col in new_df.get_column_names() {
|
|
if !specified_metrics.contains(&col.to_string())
|
|
&& !QDF_INDEX_COLUMNS.contains(&col.as_str())
|
|
{
|
|
cols_to_remove.push(col);
|
|
}
|
|
}
|
|
new_df = new_df.drop_many(
|
|
cols_to_remove
|
|
.iter()
|
|
.map(|s| s.to_string())
|
|
.collect::<Vec<String>>(),
|
|
);
|
|
// check if the df is still the same size
|
|
let new_df_size = new_df.shape();
|
|
if df_size != new_df_size {
|
|
println!(
|
|
"Reduced DataFrame from {} to {} rows",
|
|
df_size.0, new_df_size.0
|
|
);
|
|
}
|
|
Ok(new_df)
|
|
}
|