use crate::utils::misc::*; use crate::utils::qdf::core::*; use polars::prelude::*; use std::error::Error; /// The required columns for a Quantamental DataFrame. const QDF_INDEX_COLUMNS: [&str; 3] = ["real_date", "cid", "xcat"]; /// Filter a dataframe based on the given parameters. /// - `cids`: Filter by cross-sectional identifiers /// - `xcats`: Filter by extended categories /// - `metrics`: Filter by metrics /// - `start`: Filter by start date /// - `end`: Filter by end date /// - `intersect`: If true, intersect only return `cids` that are present for all `xcats`. /// Returns a new DataFrame with the filtered data, without modifying the original DataFrame. /// If no filters are provided, the original DataFrame is returned. pub fn reduce_dataframe( df: DataFrame, cids: Option>, xcats: Option>, metrics: Option>, start: Option, end: Option, intersect: bool, ) -> Result> { check_quantamental_dataframe(&df)?; // df_size let df_size = df.shape(); let mut new_df = df.clone(); let ticker_col: Column = get_ticker_column_for_quantamental_dataframe(&new_df)?; // if cids is not provided, get all unique cids let u_cids: Vec = get_unique_cids(&new_df)?; let u_xcats: Vec = get_unique_xcats(&new_df)?; let u_tickers: Vec = _get_unique_strs_from_str_column_object(&ticker_col)?; let cids_vec = cids.unwrap_or_else(|| u_cids.clone()); let specified_cids: Vec<&str> = cids_vec.iter().map(AsRef::as_ref).collect(); let xcats_vec = xcats.unwrap_or_else(|| u_xcats.clone()); let specified_xcats: Vec<&str> = xcats_vec.iter().map(AsRef::as_ref).collect(); let non_idx_cols: Vec = new_df .get_column_names() .iter() .filter(|&col| !QDF_INDEX_COLUMNS.contains(&col.as_str())) .map(|s| s.to_string()) .collect(); let specified_metrics: Vec = metrics.unwrap_or_else(|| non_idx_cols.iter().map(|s| s.to_string()).collect()); let specified_tickers: Vec = create_interesecting_tickers( &specified_cids .iter() .map(AsRef::as_ref) .collect::>(), &specified_xcats .iter() .map(AsRef::as_ref) .collect::>(), ); let keep_tickers: Vec = match intersect { // true => get_intersecting_cids_str_func(&specified_cids, &specified_xcats, &u_tickers), true => { let int_cids = get_intersecting_cids_str_func( &specified_cids .iter() .map(|&s| s.to_string()) .collect::>(), &specified_xcats .iter() .map(|&s| s.to_string()) .collect::>(), &u_tickers, ); create_interesecting_tickers( &int_cids.iter().map(AsRef::as_ref).collect::>(), &specified_xcats .iter() .map(AsRef::as_ref) .collect::>(), ) } false => specified_tickers.clone(), }; let kticks: Vec<&str> = keep_tickers .iter() .map(AsRef::as_ref) .collect::>(); // Create a boolean mask to filter rows based on the tickers let mut mask = vec![false; ticker_col.len()]; for (i, ticker) in ticker_col.str()?.iter().enumerate() { if let Some(t) = ticker { if kticks.contains(&t) { mask[i] = true; } } } let mask = BooleanChunked::from_slice("mask".into(), &mask); new_df = new_df.filter(&mask)?; // Apply date filtering if `start` or `end` is provided if let Some(start) = start { let start_date = chrono::NaiveDate::parse_from_str(&start, "%Y-%m-%d")?; new_df = new_df .lazy() .filter( col("real_date") .gt_eq(lit(start_date)) .alias("real_date") .into(), ) .collect()?; } if let Some(end) = end { let end_date = chrono::NaiveDate::parse_from_str(&end, "%Y-%m-%d")?; new_df = new_df .lazy() .filter( col("real_date") .lt_eq(lit(end_date)) .alias("real_date") .into(), ) .collect()?; } // Filter based on metrics if provided assert!(specified_metrics.len() > 0); // remove columns that are not in the specified metrics let mut cols_to_remove = Vec::new(); for col in new_df.get_column_names() { if !specified_metrics.contains(&col.to_string()) && !QDF_INDEX_COLUMNS.contains(&col.as_str()) { cols_to_remove.push(col); } } new_df = new_df.drop_many( cols_to_remove .iter() .map(|s| s.to_string()) .collect::>(), ); // check if the df is still the same size let new_df_size = new_df.shape(); if df_size != new_df_size { println!( "Reduced DataFrame from {} to {} rows", df_size.0, new_df_size.0 ); } Ok(new_df) }