use crate::utils::misc::{
    _get_unique_strs_from_str_column_object, create_intersecting_tickers,
    get_intersecting_cids_str_func, get_unique_from_str_column,
};
use polars::datatypes::DataType;
use polars::prelude::*;
use std::error::Error;

/// Check if a DataFrame is a Quantamental DataFrame.
///
/// A standard Quantamental DataFrame has the following columns:
/// - `real_date`: Date column as a date type
/// - `cid`: Column of cross-sectional identifiers
/// - `xcat`: Column of extended categories
///
/// Additionally, the DataFrame should have at least 1 more column.
/// Typically, this is one (or more) of the default JPMaQS metrics.
///
/// Note that only the three index columns are validated here (presence and
/// dtype); the presence of an additional metric column is not enforced.
///
/// # Errors
/// Returns an error naming the offending column when one of the index columns
/// is missing or does not have the expected dtype.
pub fn check_quantamental_dataframe(df: &DataFrame) -> Result<(), Box<dyn Error>> {
    let expected = [
        ("real_date", DataType::Date),
        ("cid", DataType::String),
        ("xcat", DataType::String),
    ];
    let err = "Quantamental DataFrame must have at least 4 columns: 'real_date', 'cid', 'xcat' and one or more metrics.";

    for (name, dtype) in &expected {
        // Report the column *name* in the message, not the lookup result or
        // the column contents.
        let column = match df.column(name) {
            Ok(column) => column,
            Err(_) => return Err(format!("{} Column {:?} not found", err, name).into()),
        };
        if column.dtype() != dtype {
            return Err(format!(
                "{} Column {:?} has wrong dtype: expected {:?}, found {:?}",
                err,
                name,
                dtype,
                column.dtype()
            )
            .into());
        }
    }
    Ok(())
}

/// Check if a DataFrame is a Quantamental DataFrame.
/// Returns true if the DataFrame is a Quantamental DataFrame, false otherwise.
/// Uses the `check_quantamental_dataframe` function to check if the DataFrame is a Quantamental DataFrame.
pub fn is_quantamental_dataframe(df: &DataFrame) -> bool {
    check_quantamental_dataframe(df).is_ok()
}

/// Compute the canonical column order for a Quantamental DataFrame.
///
/// The result always starts with `real_date`, `cid`, `xcat` (these are emitted
/// unconditionally, even if absent from `columns`), followed by whichever of
/// `value`, `grading`, `eop_lag`, `mop_lag` appear in `columns` (in that
/// order), followed by all remaining column names sorted alphabetically.
pub fn get_sorted_qdf_columns(columns: Vec<String>) -> Vec<String> {
    let index_columns = ["real_date", "cid", "xcat"];
    let known_metrics = ["value", "grading", "eop_lag", "mop_lag"];

    // Everything that is neither an index column nor a known metric.
    let mut unknown_metrics: Vec<String> = columns
        .iter()
        .filter(|c| {
            !known_metrics.contains(&c.as_str()) && !index_columns.contains(&c.as_str())
        })
        .cloned()
        .collect();
    unknown_metrics.sort();

    let mut new_columns: Vec<String> =
        Vec::with_capacity(index_columns.len() + known_metrics.len() + unknown_metrics.len());
    new_columns.extend(index_columns.iter().map(|s| s.to_string()));
    for &metric in &known_metrics {
        // Compare by borrowed str so no temporary String is allocated per test.
        if columns.iter().any(|c| c.as_str() == metric) {
            new_columns.push(metric.to_string());
        }
    }
    new_columns.extend(unknown_metrics);
    new_columns
}

/// Sort the columns of a Quantamental DataFrame.
/// The first columns are `real_date`, `cid`, and `xcat`.
/// These are followed by any available JPMAQS metrics, 'value', 'grading', 'eop_lag', 'mop_lag',
/// (**in that order**), followed by any other metrics (in alphabetical order).
///
/// # Errors
/// Returns an error if the reordered column selection fails, e.g. because one
/// of the index columns is not present in `qdf`.
pub fn sort_qdf_columns(qdf: &mut DataFrame) -> Result<(), Box<dyn Error>> {
    let df_columns: Vec<String> = qdf
        .get_column_names()
        .into_iter()
        .map(|s| s.to_string())
        .collect();
    let new_columns = get_sorted_qdf_columns(df_columns);
    // Propagate selection failures instead of panicking.
    *qdf = qdf.select(new_columns)?;
    Ok(())
}

/// Get intersecting cross-sections from a Quantamental DataFrame.
///
/// When `xcats` is `None`, all unique `xcat` values found in `df` are used.
/// The intersection itself is delegated to `get_intersecting_cids_str_func`,
/// which receives the unique cids, the relevant xcats and the unique tickers.
///
/// # Errors
/// Returns an error if `df` is not a valid Quantamental DataFrame or if any of
/// the unique-value lookups fail.
pub fn get_intersecting_cids(
    df: &DataFrame,
    xcats: &Option<Vec<String>>,
) -> Result<Vec<String>, Box<dyn Error>> {
    let rel_xcats = match xcats {
        Some(xcats) => xcats.clone(),
        // Propagate lookup failures instead of panicking.
        None => get_unique_xcats(df)?,
    };
    let found_tickers = get_unique_tickers(df)?;
    let found_cids = get_unique_cids(df)?;
    Ok(get_intersecting_cids_str_func(
        &found_cids,
        &rel_xcats,
        &found_tickers,
    ))
}

/// Get intersecting tickers from a Quantamental DataFrame.
#[allow(dead_code)] fn get_tickers_interesecting_on_xcat( df: &DataFrame, xcats: &Option>, ) -> Result, Box> { let rel_cids = get_intersecting_cids(df, xcats)?; let rel_xcats = xcats .clone() .unwrap_or_else(|| get_unique_xcats(df).unwrap()); let rel_cids_str: Vec<&str> = rel_cids.iter().map(AsRef::as_ref).collect(); let rel_xcats_str: Vec<&str> = rel_xcats.iter().map(AsRef::as_ref).collect(); Ok(create_intersecting_tickers(&rel_cids_str, &rel_xcats_str)) } /// Get the unique tickers from a Quantamental DataFrame. pub fn get_ticker_column_for_quantamental_dataframe( df: &DataFrame, ) -> Result> { check_quantamental_dataframe(df)?; let mut ticker_df = DataFrame::new(vec![df.column("cid")?.clone(), df.column("xcat")?.clone()])? .lazy() .select([concat_str([col("cid"), col("xcat")], "_", true)]) .collect()?; Ok(ticker_df .rename("cid", "ticker".into()) .unwrap() .column("ticker") .unwrap() .clone()) } /// Get the unique tickers from a Quantamental DataFrame. /// Returns a Vec of unique tickers. pub fn get_unique_tickers(df: &DataFrame) -> Result, Box> { let ticker_col = get_ticker_column_for_quantamental_dataframe(df)?; _get_unique_strs_from_str_column_object(&ticker_col) } /// Get the unique cross-sectional identifiers (`cids`) from a Quantamental DataFrame. pub fn get_unique_cids(df: &DataFrame) -> Result, Box> { check_quantamental_dataframe(df)?; get_unique_from_str_column(df, "cid") } /// Get the unique extended categories (`xcats`) from a Quantamental DataFrame. 
pub fn get_unique_xcats(df: &DataFrame) -> Result, Box> { check_quantamental_dataframe(df)?; get_unique_from_str_column(df, "xcat") } pub fn get_unique_metrics(df: &DataFrame) -> Result, Box> { // return a list of all columns that are not 'real_date', 'cid', 'xcat' let columns = df .get_column_names() .iter() .map(|s| s.as_str().to_string()) .collect(); let sorted_cols = get_sorted_qdf_columns(columns); // return sorted_cols[3..].to_vec() Ok(sorted_cols[3..].to_vec()) } /// Get the unique dates as a polars Column from a Quantamental DataFrame. pub fn get_unique_dates(df: &DataFrame) -> Result> { let date_col = df.column("real_date")?; let unique_dates = date_col.unique()?.sort(SortOptions::default())?; Ok(unique_dates) }