msyrs/src/utils/qdf/core.rs

173 lines
6.2 KiB
Rust

use crate::utils::misc::{
_get_unique_strs_from_str_column_object, create_intersecting_tickers,
get_intersecting_cids_str_func, get_unique_from_str_column,
};
use polars::datatypes::DataType;
use polars::prelude::*;
use std::error::Error;
/// Check if a DataFrame is a Quantamental DataFrame.
///
/// A standard Quantamental DataFrame has the following columns:
/// - `real_date`: Date column as a date type
/// - `cid`: Column of cross-sectional identifiers (string type)
/// - `xcat`: Column of extended categories (string type)
///
/// Additionally, the DataFrame should have at least 1 more column.
/// Typically, this is one (or more) of the default JPMaQS metrics.
/// NOTE: only the three index columns (presence and dtype) are validated here; the
/// presence of a metric column is not enforced.
///
/// # Errors
/// Returns an error naming the offending column if one of the index columns is missing,
/// or if it does not have the expected dtype (the expected and found dtypes are reported).
pub fn check_quantamental_dataframe(df: &DataFrame) -> Result<(), Box<dyn Error>> {
    let expected = [
        ("real_date", DataType::Date),
        ("cid", DataType::String),
        ("xcat", DataType::String),
    ];
    let err = "Quantamental DataFrame must have at least 4 columns: 'real_date', 'cid', 'xcat' and one or more metrics.";
    for (name, expected_dtype) in expected.iter() {
        // Report the column *name*, not the lookup `Result`, so the message is readable.
        let column = df
            .column(name)
            .map_err(|_| format!("{} Column {:?} not found", err, name))?;
        if column.dtype() != expected_dtype {
            // Report the dtypes rather than Debug-printing the whole column's data.
            return Err(format!(
                "{} Column {:?} has wrong dtype: expected {:?}, found {:?}",
                err,
                name,
                expected_dtype,
                column.dtype()
            )
            .into());
        }
    }
    Ok(())
}
/// Report whether `df` passes [`check_quantamental_dataframe`].
///
/// Returns `true` when the check succeeds and `false` otherwise; the error detail
/// produced by the check is discarded.
pub fn is_quantamental_dataframe(df: &DataFrame) -> bool {
    matches!(check_quantamental_dataframe(df), Ok(()))
}
/// Compute the canonical column order for a Quantamental DataFrame.
///
/// The result always starts with `real_date`, `cid`, `xcat` (even if they are absent
/// from `columns`). Next come the known JPMaQS metrics present in `columns`, in the fixed
/// order `value`, `grading`, `eop_lag`, `mop_lag`. Every remaining column name follows,
/// sorted alphabetically.
pub fn get_sorted_qdf_columns(columns: Vec<String>) -> Vec<String> {
    const INDEX_COLUMNS: [&str; 3] = ["real_date", "cid", "xcat"];
    const KNOWN_METRICS: [&str; 4] = ["value", "grading", "eop_lag", "mop_lag"];

    // Known metrics keep their canonical order, each emitted at most once.
    let present_known = KNOWN_METRICS
        .iter()
        .filter(|metric| columns.iter().any(|c| c.as_str() == **metric))
        .map(|metric| metric.to_string());

    // Everything that is neither an index column nor a known metric.
    let mut extras: Vec<String> = columns
        .iter()
        .filter(|c| {
            let name = c.as_str();
            !INDEX_COLUMNS.contains(&name) && !KNOWN_METRICS.contains(&name)
        })
        .cloned()
        .collect();
    extras.sort();

    INDEX_COLUMNS
        .iter()
        .map(|s| s.to_string())
        .chain(present_known)
        .chain(extras)
        .collect()
}
/// Sort the columns of a Quantamental DataFrame in place.
///
/// The first columns are `real_date`, `cid`, and `xcat`.
/// These are followed by any available JPMaQS metrics, `value`, `grading`, `eop_lag`,
/// `mop_lag` (**in that order**), followed by any other metrics (in alphabetical order).
/// The ordering is computed by [`get_sorted_qdf_columns`].
///
/// # Errors
/// Returns an error instead of panicking if the column selection fails. This can happen,
/// for example, when `qdf` lacks one of `real_date`, `cid` or `xcat`, because the
/// computed ordering always includes them. On error, `qdf` is left unchanged.
pub fn sort_qdf_columns(qdf: &mut DataFrame) -> Result<(), Box<dyn Error>> {
    let df_columns = qdf
        .get_column_names()
        .into_iter()
        .map(|s| s.to_string())
        .collect::<Vec<String>>();
    let new_columns = get_sorted_qdf_columns(df_columns);
    *qdf = qdf.select(new_columns)?;
    Ok(())
}
/// Get intersecting cross-sections from a Quantamental DataFrame.
///
/// `xcats` selects the extended categories to intersect on. When it is `None`, every
/// xcat found in `df` is used. The unique tickers and cids of `df` are passed, together
/// with those xcats, to `get_intersecting_cids_str_func`, which computes the result.
///
/// # Errors
/// Returns an error (rather than panicking) if `df` is not a valid Quantamental
/// DataFrame or if extracting its unique xcats/tickers/cids fails.
pub fn get_intersecting_cids(
    df: &DataFrame,
    xcats: &Option<Vec<String>>,
) -> Result<Vec<String>, Box<dyn Error>> {
    // Borrow the caller's xcats when given; only fetch (and own) them when absent.
    let fetched;
    let rel_xcats: &Vec<String> = match xcats {
        Some(x) => x,
        None => {
            fetched = get_unique_xcats(df)?;
            &fetched
        }
    };
    let found_tickers = get_unique_tickers(df)?;
    let found_cids = get_unique_cids(df)?;
    let keep_cids = get_intersecting_cids_str_func(&found_cids, rel_xcats, &found_tickers);
    Ok(keep_cids)
}
/// Get intersecting tickers from a Quantamental DataFrame.
///
/// The cids come from [`get_intersecting_cids`]. The xcats are `xcats`, or every xcat
/// in `df` when `xcats` is `None`. Both lists are passed to `create_intersecting_tickers`
/// to build the tickers.
///
/// # Errors
/// Returns an error (rather than panicking) if `df` is not a valid Quantamental
/// DataFrame or if any of the underlying lookups fail.
// Not called anywhere visible in this module (presumably kept as a helper for future
// use), so the dead-code lint is silenced deliberately.
#[allow(dead_code)]
fn get_tickers_interesecting_on_xcat(
    df: &DataFrame,
    xcats: &Option<Vec<String>>,
) -> Result<Vec<String>, Box<dyn Error>> {
    let rel_cids = get_intersecting_cids(df, xcats)?;
    // Borrow the caller's xcats when given; only fetch (and own) them when absent.
    let fetched;
    let rel_xcats: &Vec<String> = match xcats {
        Some(x) => x,
        None => {
            fetched = get_unique_xcats(df)?;
            &fetched
        }
    };
    let rel_cids_str: Vec<&str> = rel_cids.iter().map(String::as_str).collect();
    let rel_xcats_str: Vec<&str> = rel_xcats.iter().map(String::as_str).collect();
    Ok(create_intersecting_tickers(&rel_cids_str, &rel_xcats_str))
}
/// Build the ticker column of a Quantamental DataFrame.
///
/// Each ticker is the row's `cid` and `xcat` joined with `"_"`. The returned column is
/// named `ticker` and has one entry per row of `df`. It is **not** deduplicated; see
/// [`get_unique_tickers`] for the unique values.
///
/// # Errors
/// Returns an error if `df` fails [`check_quantamental_dataframe`], or if selecting the
/// columns or evaluating the polars expression fails.
pub fn get_ticker_column_for_quantamental_dataframe(
    df: &DataFrame,
) -> Result<Column, Box<dyn Error>> {
    check_quantamental_dataframe(df)?;
    let ticker_df = df
        .select(["cid", "xcat"])?
        .lazy()
        // Name the output explicitly instead of relying on polars naming the
        // concatenation after its first input ("cid") and renaming afterwards.
        // The final `true` is polars' `ignore_nulls` flag.
        .select([concat_str([col("cid"), col("xcat")], "_", true).alias("ticker")])
        .collect()?;
    Ok(ticker_df.column("ticker")?.clone())
}
/// List the distinct tickers of a Quantamental DataFrame.
///
/// The per-row ticker column comes from
/// [`get_ticker_column_for_quantamental_dataframe`] (which also validates `df`). Its
/// distinct string values are then extracted into a `Vec<String>`.
pub fn get_unique_tickers(df: &DataFrame) -> Result<Vec<String>, Box<dyn Error>> {
    _get_unique_strs_from_str_column_object(&get_ticker_column_for_quantamental_dataframe(df)?)
}
/// List the distinct cross-sectional identifiers (`cids`) of a Quantamental DataFrame.
///
/// `df` is validated with [`check_quantamental_dataframe`] first; on success the
/// distinct values of the `cid` column are returned.
pub fn get_unique_cids(df: &DataFrame) -> Result<Vec<String>, Box<dyn Error>> {
    check_quantamental_dataframe(df).and_then(|()| get_unique_from_str_column(df, "cid"))
}
/// List the distinct extended categories (`xcats`) of a Quantamental DataFrame.
///
/// `df` is validated with [`check_quantamental_dataframe`] first; on success the
/// distinct values of the `xcat` column are returned.
pub fn get_unique_xcats(df: &DataFrame) -> Result<Vec<String>, Box<dyn Error>> {
    check_quantamental_dataframe(df).and_then(|()| get_unique_from_str_column(df, "xcat"))
}
/// Get the metric column names of a Quantamental DataFrame.
///
/// A metric column is any column other than `real_date`, `cid` and `xcat`. Metrics are
/// ordered as by [`get_sorted_qdf_columns`]: the known JPMaQS metrics (`value`,
/// `grading`, `eop_lag`, `mop_lag`) come first in that order, then any other columns
/// alphabetically.
///
/// NOTE: `df` is not validated as a Quantamental DataFrame here, and this function
/// currently never returns `Err`.
pub fn get_unique_metrics(df: &DataFrame) -> Result<Vec<String>, Box<dyn Error>> {
    let columns: Vec<String> = df
        .get_column_names()
        .iter()
        .map(|s| s.as_str().to_string())
        .collect();
    let mut sorted_cols = get_sorted_qdf_columns(columns);
    // `get_sorted_qdf_columns` always emits the three index columns first, so the length
    // is at least 3 and everything after them is a metric.
    Ok(sorted_cols.split_off(3))
}
/// Return the distinct values of the `real_date` column as a polars `Column`.
///
/// The values are sorted with `SortOptions::default()`. `df` is not otherwise validated
/// as a Quantamental DataFrame; only the `real_date` column is required.
pub fn get_unique_dates(df: &DataFrame) -> Result<Column, Box<dyn Error>> {
    let distinct = df.column("real_date")?.unique()?;
    Ok(distinct.sort(SortOptions::default())?)
}