From d4721a0b796adc1821bb61aa268fa767735f772f Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Sun, 17 Nov 2024 02:02:31 +0000 Subject: [PATCH] update module structure --- .gitignore | 2 +- src/download/helpers.rs | 4 +- src/download/parreq.rs | 2 +- src/lib.rs | 2 +- src/main.rs | 116 +++++++++- src/utils/dftools.rs | 443 ------------------------------------- src/utils/misc.rs | 28 ++- src/utils/mod.rs | 4 +- src/utils/qdf/core.rs | 143 ++++++++++++ src/utils/qdf/load.rs | 198 +++++++++++++++++ src/utils/qdf/mod.rs | 9 + src/utils/qdf/reduce_df.rs | 151 +++++++++++++ src/utils/qdf/update_df.rs | 35 +++ 13 files changed, 668 insertions(+), 469 deletions(-) delete mode 100644 src/utils/dftools.rs create mode 100644 src/utils/qdf/core.rs create mode 100644 src/utils/qdf/load.rs create mode 100644 src/utils/qdf/mod.rs create mode 100644 src/utils/qdf/reduce_df.rs create mode 100644 src/utils/qdf/update_df.rs diff --git a/.gitignore b/.gitignore index 6cde882..53a9aa8 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,7 @@ dev/ *.pyc __pycache__/ *.log - +.idea/ /target data/ \ No newline at end of file diff --git a/src/download/helpers.rs b/src/download/helpers.rs index a4e0c79..064f189 100644 --- a/src/download/helpers.rs +++ b/src/download/helpers.rs @@ -402,7 +402,7 @@ fn timeseries_list_to_dataframe( timeseries_list: Vec, dropna: bool, ) -> Result> { - let mut output_df = DataFrame::new(vec![]).expect("Failed to create DataFrame"); + let mut output_df: DataFrame; if let Some((first, rest)) = timeseries_list.split_first() { // Convert the first timeseries to DataFrame and clone it to avoid modifying the original @@ -438,7 +438,7 @@ fn timeseries_list_to_dataframe( output_df = result_df.clone(); } else { - println!("No timeseries provided."); + return Err("No timeseries provided".into()); } // drop rows where all values are NA diff --git a/src/download/parreq.rs b/src/download/parreq.rs index 613c32b..cae9caf 100644 --- a/src/download/parreq.rs +++ b/src/download/parreq.rs @@ -1,5 +1,5 @@ -use crate::download::oauth_client::OAuthClient; use crate::download::helpers::DQTimeseriesRequestArgs; +use crate::download::oauth_client::OAuthClient; use futures::future; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use std::error::Error; diff --git a/src/lib.rs b/src/lib.rs index fa9cba6..6bcf0ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ #![doc = include_str!("../README.md")] pub mod download; -pub mod utils; \ No newline at end of file +pub mod utils; diff --git a/src/main.rs b/src/main.rs index 1e9dd0d..efbeace 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,9 @@ use msyrs::download::jpmaqsdownload::{JPMaQSDownload, JPMaQSDownloadGetIndicatorArgs}; -use msyrs::utils::dftools as msyrs_dftools; +// use msyrs::utils::qdf::load::*; +// use msyrs::utils::qdf::dftools::*; +// use msyrs::utils::qdf::core::*; +use msyrs::utils::qdf::*; + #[allow(dead_code)] fn download_stuff() { @@ -42,7 +46,7 @@ fn download_stuff() { start.elapsed() ); - if !msyrs_dftools::is_quantamental_dataframe(&res_df) { + if !is_quantamental_dataframe(&res_df) { println!("DataFrame is not a quantamental DataFrame"); } else { println!("DataFrame is a quantamental DataFrame"); @@ -56,16 +60,110 @@ fn main() { // println!("{:?}", df); // load_quantamental_dataframe_from_download_bank // let st_pth = "E:/Work/ruzt/msyrs/data/JPMaQSData/"; - let st_pth = "E:/Work/jpmaqs-isc-git/jpmaqs-iscs"; - let df = 
msyrs_dftools::load_quantamental_dataframe_from_download_bank( + let start = std::time::Instant::now(); + let st_pth = "E:\\Work\\jpmaqs-data\\data"; + + let mega_df = load_quantamental_dataframe_from_download_bank( st_pth, - Some(vec!["AUD", "USD", "GBP", "JPY", "EUR", "CAD", "CHF", "INR", "CNY"]), - Some(vec!["EQXR_NSA", "FXXR_NSA", "RIR_NSA", "ALLIFCDSGDP_NSA"]), + // Some(vec!["AUD", "USD", "GBP", "JPY"]), + // Some(vec!["RIR_NSA", "EQXR_NSA"]), None, + None, + // Some(vec!["EQXR_NSA", "RIR_NSA"]), + // None + Some(vec![ + "AUD_EQXR_NSA", + "USD_EQXR_NSA", + "GBP_EQXR_NSA", + "JPY_EQXR_NSA", + "AUD_RIR_NSA", + "USD_RIR_NSA", + "GBP_RIR_NSA", + "JPY_RIR_NSA", + ]), ) - .unwrap(); - println!("{:?}", df); + .unwrap(); - // download_stuff(); + let end = start.elapsed(); + println!("Loaded Mega DataFrame in {:?}", end); + + let start = std::time::Instant::now(); + let df_new = reduce_dataframe( + mega_df.clone(), + Some(vec![ + "GBP".to_string(), + "AUD".to_string(), + "USD".to_string(), + ]), + Some(vec!["RIR_NSA".to_string(), "EQXR_NSA".to_string()]), + None, + Some("2010-01-20"), + None, + false, + ) + .unwrap(); + let end = start.elapsed(); + println!("Reduced Mega DataFrame in {:?}", end); + + // FOUND TICKERs + let start = std::time::Instant::now(); + let found_tickers = get_unique_tickers(&df_new); + let end = start.elapsed(); + println!( + "Found {:?} unique tickers in df_new", + found_tickers.unwrap() + ); + println!("Found unique tickers in {:?}", end); + + let end = start.elapsed(); + println!("Loaded DataFrame in {:?}", end); + let start = std::time::Instant::now(); + let df_gbp = reduce_dataframe( + df_new.clone(), + Some(vec!["GBP".to_string()]), + Some(vec!["RIR_NSA".to_string()]), + None, + Some("2024-11-12"), + None, + false, + ) + .unwrap(); + let end = start.elapsed(); + println!("Reduced DataFrame in {:?}", end); + // println!("{:?}", df_gbp.head(Some(10))); + + // FOUND TICKERs + let start = std::time::Instant::now(); + let found_tickers = get_unique_tickers(&mega_df); + let end = start.elapsed(); + println!( + "Found {:?} unique tickers in Mega DataFrame", + found_tickers.unwrap() + ); + println!("Found unique tickers in {:?}", end); + + let start = std::time::Instant::now(); + let df_aud = reduce_dataframe( + df_new.clone(), + Some(vec!["USD".to_string()]), + // Some(vec!["EQXR_NSA".to_string(), "RIR_NSA".to_string()]), + Some(vec!["EQXR_NSA".to_string()]), + None, + Some("2024-11-13"), + None, + true, + ) + .unwrap(); + let end = start.elapsed(); + println!("Reduced DataFrame in {:?}", end); + // dimenstions reduced from to + println!("{:?} from {:?}", df_aud.shape(), df_new.shape()); + // println!("{:?}", df_aud.head(Some(10))); + + let start = std::time::Instant::now(); + let up_df = update_dataframe(&df_gbp, &df_aud).unwrap(); + let end = start.elapsed(); + println!("Updated DataFrame in {:?}", end); + println!("{:?}", up_df.head(Some(10))); } diff --git a/src/utils/dftools.rs b/src/utils/dftools.rs deleted file mode 100644 index 1055091..0000000 --- a/src/utils/dftools.rs +++ /dev/null @@ -1,443 +0,0 @@ -use crate::utils::misc::*; -use anyhow; -use polars::datatypes::DataType; -use polars::prelude::*; -use rayon::prelude::*; -use std::error::Error; -use std::fs; - -/// The standard metrics provided by JPMaQS (`value`, `grading`, `eop_lag`, `mop_lag`). -pub const DEFAULT_JPMAQS_METRICS: [&str; 4] = ["value", "grading", "eop_lag", "mop_lag"]; - -/// The required columns for a Quantamental DataFrame. 
-pub const QDF_INDEX_COLUMNS: [&str; 3] = ["real_date", "cid", "xcat"]; - -/// Check if a DataFrame is a quantamental DataFrame. -/// A standard Quantamental DataFrame has the following columns: -/// - `real_date`: Date column as a date type -/// - `cid`: Column of cross-sectional identifiers -/// - `xcat`: Column of extended categories -/// -/// Additionally, the DataFrame should have atleast 1 more column. -/// Typically, this is one (or more) of the default JPMaQS metics. -pub fn check_quantamental_dataframe(df: &DataFrame) -> Result<(), Box> { - let expected_cols = ["real_date", "cid", "xcat"]; - let expected_dtype = [DataType::Date, DataType::String, DataType::String]; - for (col, dtype) in expected_cols.iter().zip(expected_dtype.iter()) { - let col = df.column(col); - if col.is_err() { - return Err(format!("Column {:?} not found", col).into()); - } - let col = col?; - if col.dtype() != dtype { - return Err(format!("Column {:?} has wrong dtype", col).into()); - } - } - Ok(()) -} - -/// Check if a DataFrame is a quantamental DataFrame. -/// Returns true if the DataFrame is a quantamental DataFrame, false otherwise. -/// Uses the `check_quantamental_dataframe` function to check if the DataFrame is a quantamental DataFrame. -pub fn is_quantamental_dataframe(df: &DataFrame) -> bool { - check_quantamental_dataframe(df).is_ok() -} - -pub fn sort_qdf_columns(qdf: &mut DataFrame) -> Result<(), Box> { - let index_columns = ["real_date", "cid", "xcat"]; - let known_metrics = ["value", "grading", "eop_lag", "mop_lag"]; - - let df_columns = qdf - .get_column_names() - .into_iter() - .map(|s| s.clone().into_string()) - .collect::>(); - - let mut unknown_metrics: Vec = df_columns - .iter() - .filter(|&m| !known_metrics.contains(&m.as_str())) - .filter(|&m| !index_columns.contains(&m.as_str())) - .cloned() - .collect(); - - let mut new_columns: Vec = vec![]; - new_columns.extend(index_columns.iter().map(|s| s.to_string())); - for &colname in &known_metrics { - if df_columns.contains(&colname.into()) { - new_columns.push(colname.to_string()); - } - } - - unknown_metrics.sort(); - new_columns.extend(unknown_metrics); - *qdf = qdf - .select(new_columns.clone()) - .expect("Failed to select columns"); - - Ok(()) -} - -fn _file_base_name(file_path: String) -> String { - std::path::Path::new(&file_path.clone()) - .file_stem() - .unwrap() - .to_str() - .unwrap() - .to_string() -} - -/// Load a Quantamental DataFrame from a CSV file. -/// The CSV must be named in the format `cid_xcat.csv` (`ticker.csv`). -/// The DataFrame must have a `real_date` column along with additional value columns. 
-pub fn load_quantamental_dataframe( - file_path: &str, -) -> Result> { - // get the file base name - let base_file_name = _file_base_name(file_path.into()); - - // if filename does not have _ then it is not a Quantamental DataFrame - if !base_file_name.contains('_') { - return Err("The file name must be in the format `cid_xcat.csv` (`ticker.csv`)".into()); - } - - let ticker = base_file_name.split('.').collect::>()[0]; - let (cid, xcat) = split_ticker(ticker.to_string())?; - - let mut df = CsvReadOptions::default() - .try_into_reader_with_file_path(Some(file_path.into())) - .unwrap() - .finish() - .unwrap(); - - let err = "The dataframe must have a `real_date` column and atleast 1 additional value column"; - if df.column("real_date").is_err() || df.width() < 2 { - return Err(err.into()); - } - - // check if the first item in the real_date column has a dash or not - let has_dashes = df - .column("real_date") - .unwrap() - .cast(&DataType::String)? - .get(0) - .unwrap() - .to_string() - .contains('-'); - - let date_format = if has_dashes { "%Y-%m-%d" } else { "%Y%m%d" }; - // let real_date_col = df - // .column("real_date".into()) - // .unwrap() - // .cast(&DataType::Date)?; - let real_date_col = df - .column("real_date") - .unwrap() - // .str()? - .cast(&DataType::String)? - .str()? - .as_date(Some(date_format), false) - .map_err(|e| format!("Failed to parse date with format {}: {}", date_format, e))?; - - df.with_column(real_date_col)?; - df.with_column(Series::new("cid".into(), vec![cid; df.height()]))?; - df.with_column(Series::new("xcat".into(), vec![xcat; df.height()]))?; - - sort_qdf_columns(&mut df)?; - - Ok(df) -} - -fn _load_qdf_thread_safe(file_path: &str) -> Result> { - let res = load_quantamental_dataframe(file_path); - res.map_err(|e| { - anyhow::Error::msg(e.to_string()) - .context("Failed to load quantamental dataframe") - .into() - }) -} -pub fn load_quantamental_dataframe_from_download_bank( - folder_path: &str, - cids: Option>, - xcats: Option>, - tickers: Option>, -) -> Result> { - let rcids = cids.unwrap_or_else(|| Vec::new()); - let rxcats = xcats.unwrap_or_else(|| Vec::new()); - let rtickers = tickers.unwrap_or_else(|| Vec::new()); - - // recursively read list of all files in the folder as a vector of strings - let files = fs::read_dir(folder_path)? 
- .map(|res| res.map(|e| e.path().display().to_string())) - .collect::, std::io::Error>>()?; - - // print number of files found - - // filter files that are not csv files - let files = files - .iter() - .filter(|f| f.ends_with(".csv")) - .collect::>(); - - // print number of csv files found - - let mut rel_files = Vec::new(); - for file in files { - let base_file_name: String = _file_base_name(file.into()) - .split('.') - .collect::>()[0] - .into(); - let (cid, xcat) = match split_ticker(base_file_name.clone()) { - Ok((cid, xcat)) => (cid, xcat), - Err(_) => continue, - }; - rel_files.push((file, cid, xcat)); - } - - // print number of relevant ticker files found - - let load_files = rel_files - .iter() - .filter(|(_, cid, xcat)| { - let f1 = rcids.len() > 0 && rcids.contains(&cid.as_str()); - let f2 = rxcats.len() > 0 && rxcats.contains(&xcat.as_str()); - let f3 = rtickers.len() > 0 && rtickers.contains(&create_ticker(cid, xcat).as_str()); - f1 | f2 | f3 - }) - .map(|(file, _, _)| *file) - .collect::>(); - - // print number of files to load - println!("Loading {} files", load_files.len()); - - if load_files.len() == 0 { - return Err("No files to load".into()); - } - if load_files.len() == 1 { - let dfx = load_quantamental_dataframe(load_files[0]).unwrap(); - return Ok(dfx); - } - - let load_files = load_files.iter().map(|s| s.as_str()).collect::>(); - let qdf_batches = load_files.chunks(500).collect::>(); - - let mut results = Vec::new(); - let mut curr_batch = 1; - let total_batches = qdf_batches.len(); - for batch in qdf_batches { - let qdf_list = batch - .par_iter() - .map(|file| _load_qdf_thread_safe(file).unwrap()) - .collect::>(); - results.extend(qdf_list); - curr_batch += 1; - } - - println!("Loaded {} files", results.len()); - - let mut res_df: DataFrame = results.pop().unwrap(); - while let Some(df) = results.pop() { - res_df = res_df.vstack(&df).unwrap(); - } - - Ok(res_df) -} - -/// Get intersecting cross-sections from a DataFrame. -pub fn get_intersecting_cids( - df: &DataFrame, - xcats: &Option>, -) -> Result, Box> { - let rel_xcats = xcats - .clone() - .unwrap_or_else(|| get_unique_xcats(df).unwrap()); - let found_tickers = get_unique_tickers(df)?; - let found_cids = get_unique_cids(df)?; - let keep_cids = get_intersecting_cids_str_func(&found_cids, &rel_xcats, &found_tickers); - Ok(keep_cids) -} - -/// Get intersecting tickers from a DataFrame. -#[allow(dead_code)] -fn get_tickers_interesecting_on_xcat( - df: &DataFrame, - xcats: &Option>, -) -> Result, Box> { - let rel_cids = get_intersecting_cids(df, xcats)?; - let rel_xcats = xcats - .clone() - .unwrap_or_else(|| get_unique_xcats(df).unwrap()); - let rel_cids_str: Vec<&str> = rel_cids.iter().map(AsRef::as_ref).collect(); - let rel_xcats_str: Vec<&str> = rel_xcats.iter().map(AsRef::as_ref).collect(); - Ok(create_interesecting_tickers(&rel_cids_str, &rel_xcats_str)) -} - -/// Get the unique tickers from a Quantamental DataFrame. -pub fn get_ticker_column_for_quantamental_dataframe( - df: &DataFrame, -) -> Result> { - check_quantamental_dataframe(df)?; - let mut ticker_df = - DataFrame::new(vec![df.column("cid")?.clone(), df.column("xcat")?.clone()])? - .lazy() - .select([concat_str([col("cid"), col("xcat")], "_", true)]) - .collect()?; - - Ok(ticker_df - .rename("cid", "ticker".into()) - .unwrap() - .column("ticker") - .unwrap() - .clone()) -} - -/// Get the unique tickers from a DataFrame. -/// Returns a Vec of unique tickers. 
-pub fn get_unique_tickers(df: &DataFrame) -> Result, Box> { - let ticker_col = get_ticker_column_for_quantamental_dataframe(df)?; - _get_unique_strs_from_str_column_object(&ticker_col) -} - -/// Get the unique cross-sectional identifiers (`cids`) from a DataFrame. -pub fn get_unique_cids(df: &DataFrame) -> Result, Box> { - check_quantamental_dataframe(df)?; - get_unique_from_str_column(df, "cid") -} - -/// Get the unique extended categories (`xcats`) from a DataFrame. -pub fn get_unique_xcats(df: &DataFrame) -> Result, Box> { - check_quantamental_dataframe(df)?; - get_unique_from_str_column(df, "xcat") -} - -/// Filter a dataframe based on the given parameters. -/// - `cids`: Filter by cross-sectional identifiers -/// - `xcats`: Filter by extended categories -/// - `metrics`: Filter by metrics -/// - `start`: Filter by start date -/// - `end`: Filter by end date -/// - `intersect`: If true, intersect only return `cids` that are present for all `xcats`. -/// Returns a new DataFrame with the filtered data, without modifying the original DataFrame. -/// If no filters are provided, the original DataFrame is returned. -pub fn reduce_dataframe( - df: &DataFrame, - cids: Option>, - xcats: Option>, - metrics: Option>, - start: Option<&str>, - end: Option<&str>, - intersect: bool, -) -> Result> { - check_quantamental_dataframe(df)?; - - let mut new_df: DataFrame = df.clone(); - - let ticker_col: Column = get_ticker_column_for_quantamental_dataframe(&new_df)?; - - // if cids is not provided, get all unique cids - let u_cids: Vec = get_unique_cids(&new_df)?; - let u_xcats: Vec = get_unique_xcats(&new_df)?; - let u_tickers: Vec = _get_unique_strs_from_str_column_object(&ticker_col)?; - - let specified_cids: Vec = cids.unwrap_or_else(|| u_cids.clone()); - let specified_xcats: Vec = xcats.unwrap_or_else(|| u_xcats.clone()); - let specified_metrics: Vec = metrics.unwrap_or_else(|| { - DEFAULT_JPMAQS_METRICS - .iter() - .map(|&s| s.to_string()) - .collect() - }); - let specified_tickers: Vec = create_interesecting_tickers( - &specified_cids - .iter() - .map(AsRef::as_ref) - .collect::>(), - &specified_xcats - .iter() - .map(AsRef::as_ref) - .collect::>(), - ); - - let keep_tickers: Vec = match intersect { - true => get_intersecting_cids_str_func(&u_cids, &u_xcats, &u_tickers), - false => specified_tickers.clone(), - }; - let kticks: Vec<&str> = keep_tickers - .iter() - .map(AsRef::as_ref) - .collect::>(); - - // Create a boolean mask to filter rows based on the tickers - let mut mask = vec![false; ticker_col.len()]; - for (i, ticker) in ticker_col.str()?.iter().enumerate() { - if let Some(t) = ticker { - if kticks.contains(&t) { - mask[i] = true; - } - } - } - let mask = BooleanChunked::from_slice("mask".into(), &mask); - new_df = new_df.filter(&mask)?; - - // Apply date filtering if `start` or `end` is provided - if let Some(start_date) = start { - new_df = new_df - .lazy() - .filter(col("real_date").gt_eq(start_date)) - .collect()?; - } - if let Some(end_date) = end { - new_df = new_df - .lazy() - .filter(col("real_date").lt_eq(end_date)) - .collect()?; - } - - // Filter based on metrics if provided - assert!(specified_metrics.len() > 0); - - // remove columns that are not in the specified metrics - let mut cols_to_remove = Vec::new(); - for col in new_df.get_column_names() { - if !specified_metrics.contains(&col.to_string()) { - cols_to_remove.push(col); - } - } - new_df = new_df.drop_many( - cols_to_remove - .iter() - .map(|s| s.to_string()) - .collect::>(), - ); - - Ok(new_df) -} - -/// Update a 
Quantamental DataFrame with new data. -/// - `df`: The original DataFrame -/// - `df_add`: The new DataFrame to add -/// -pub fn update_dataframe( - df: &DataFrame, - df_add: &DataFrame, - // xcat_replace: Option<&str>, -) -> Result> { - check_quantamental_dataframe(df)?; - check_quantamental_dataframe(df_add)?; - if df.is_empty() { - return Ok(df_add.clone()); - } else if df_add.is_empty() { - return Ok(df.clone()); - }; - - // vstack and drop duplicates keeping last - let mut new_df = df.vstack(df_add)?; - // help? - let idx_cols_vec = QDF_INDEX_COLUMNS - .iter() - .map(|s| s.to_string()) - .collect::>(); - - new_df = new_df.unique_stable(Some(&idx_cols_vec), UniqueKeepStrategy::Last, None)?; - - Ok(new_df) -} diff --git a/src/utils/misc.rs b/src/utils/misc.rs index 3e27133..c664708 100644 --- a/src/utils/misc.rs +++ b/src/utils/misc.rs @@ -63,21 +63,29 @@ pub fn get_intersecting_cids_str_func( xcats: &Vec, found_tickers: &Vec, ) -> Vec { - let mut keep_cids = cids.clone(); // make a hashmap of cids to xcats let mut cid_xcat_map = HashMap::new(); for ticker in found_tickers { let (cid, xcat) = split_ticker(ticker.clone()).unwrap(); - cid_xcat_map.insert(cid.to_string(), xcat.to_string()); - } - - // filter out cids that are not present in all xcats - for (cid, xcats_for_cid) in cid_xcat_map.iter() { - // if the all xcats are not present, remove the cid - if !xcats.iter().all(|xcat| xcats_for_cid.contains(xcat)) { - keep_cids.retain(|x| x != cid); + // if the cid is not in the map, add it + if !cid_xcat_map.contains_key(&cid) { + cid_xcat_map.insert(cid.clone(), vec![xcat.clone()]); + } else { + cid_xcat_map.get_mut(&cid).unwrap().push(xcat.clone()); } } - keep_cids + let mut found_cids: Vec = cid_xcat_map.keys().map(|x| x.clone()).collect(); + found_cids.retain(|x| cids.contains(x)); + + let mut new_keep_cids: Vec = Vec::new(); + for cid in found_cids { + let xcats_for_cid = cid_xcat_map.get(&cid).unwrap(); + let mut found_xcats: Vec = xcats_for_cid.iter().map(|x| x.clone()).collect(); + found_xcats.retain(|x| xcats.contains(x)); + if found_xcats.len() > 0 { + new_keep_cids.push(cid); + } + } + new_keep_cids } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 7c7a0b2..39220a4 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,2 +1,2 @@ -pub mod dftools; -pub mod misc; \ No newline at end of file +pub mod qdf; +pub mod misc; diff --git a/src/utils/qdf/core.rs b/src/utils/qdf/core.rs new file mode 100644 index 0000000..1070deb --- /dev/null +++ b/src/utils/qdf/core.rs @@ -0,0 +1,143 @@ +use crate::utils::misc::{ + _get_unique_strs_from_str_column_object, create_interesecting_tickers, + get_intersecting_cids_str_func, get_unique_from_str_column, +}; +use polars::datatypes::DataType; +use polars::prelude::*; +use std::error::Error; + +/// Check if a DataFrame is a quantamental DataFrame. +/// A standard Quantamental DataFrame has the following columns: +/// - `real_date`: Date column as a date type +/// - `cid`: Column of cross-sectional identifiers +/// - `xcat`: Column of extended categories +/// +/// Additionally, the DataFrame should have atleast 1 more column. +/// Typically, this is one (or more) of the default JPMaQS metics. 
+pub fn check_quantamental_dataframe(df: &DataFrame) -> Result<(), Box> { + let expected_cols = ["real_date", "cid", "xcat"]; + let expected_dtype = [DataType::Date, DataType::String, DataType::String]; + for (col, dtype) in expected_cols.iter().zip(expected_dtype.iter()) { + let col = df.column(col); + if col.is_err() { + return Err(format!("Column {:?} not found", col).into()); + } + let col = col?; + if col.dtype() != dtype { + return Err(format!("Column {:?} has wrong dtype", col).into()); + } + } + Ok(()) +} + +/// Check if a DataFrame is a quantamental DataFrame. +/// Returns true if the DataFrame is a quantamental DataFrame, false otherwise. +/// Uses the `check_quantamental_dataframe` function to check if the DataFrame is a quantamental DataFrame. +pub fn is_quantamental_dataframe(df: &DataFrame) -> bool { + check_quantamental_dataframe(df).is_ok() +} + +/// Sort the columns of a Quantamental DataFrame. +/// The first columns are `real_date`, `cid`, and `xcat`. +/// These are followed by any available JPMAQS metrics, 'value', 'grading', 'eop_lag', 'mop_lag', +/// (**in that order**), followed by any other metrics (in alphabetical order). +pub fn sort_qdf_columns(qdf: &mut DataFrame) -> Result<(), Box> { + let index_columns = ["real_date", "cid", "xcat"]; + let known_metrics = ["value", "grading", "eop_lag", "mop_lag"]; + + let df_columns = qdf + .get_column_names() + .into_iter() + .map(|s| s.clone().into_string()) + .collect::>(); + + let mut unknown_metrics: Vec = df_columns + .iter() + .filter(|&m| !known_metrics.contains(&m.as_str())) + .filter(|&m| !index_columns.contains(&m.as_str())) + .cloned() + .collect(); + + let mut new_columns: Vec = vec![]; + new_columns.extend(index_columns.iter().map(|s| s.to_string())); + for &colname in &known_metrics { + if df_columns.contains(&colname.into()) { + new_columns.push(colname.to_string()); + } + } + + unknown_metrics.sort(); + new_columns.extend(unknown_metrics); + *qdf = qdf + .select(new_columns.clone()) + .expect("Failed to select columns"); + + Ok(()) +} + +/// Get intersecting cross-sections from a DataFrame. +pub fn get_intersecting_cids( + df: &DataFrame, + xcats: &Option>, +) -> Result, Box> { + let rel_xcats = xcats + .clone() + .unwrap_or_else(|| get_unique_xcats(df).unwrap()); + let found_tickers = get_unique_tickers(df)?; + let found_cids = get_unique_cids(df)?; + let keep_cids = get_intersecting_cids_str_func(&found_cids, &rel_xcats, &found_tickers); + Ok(keep_cids) +} + +/// Get intersecting tickers from a DataFrame. +#[allow(dead_code)] +fn get_tickers_interesecting_on_xcat( + df: &DataFrame, + xcats: &Option>, +) -> Result, Box> { + let rel_cids = get_intersecting_cids(df, xcats)?; + let rel_xcats = xcats + .clone() + .unwrap_or_else(|| get_unique_xcats(df).unwrap()); + let rel_cids_str: Vec<&str> = rel_cids.iter().map(AsRef::as_ref).collect(); + let rel_xcats_str: Vec<&str> = rel_xcats.iter().map(AsRef::as_ref).collect(); + Ok(create_interesecting_tickers(&rel_cids_str, &rel_xcats_str)) +} + +/// Get the unique tickers from a Quantamental DataFrame. +pub fn get_ticker_column_for_quantamental_dataframe( + df: &DataFrame, +) -> Result> { + check_quantamental_dataframe(df)?; + let mut ticker_df = + DataFrame::new(vec![df.column("cid")?.clone(), df.column("xcat")?.clone()])? + .lazy() + .select([concat_str([col("cid"), col("xcat")], "_", true)]) + .collect()?; + + Ok(ticker_df + .rename("cid", "ticker".into()) + .unwrap() + .column("ticker") + .unwrap() + .clone()) +} + +/// Get the unique tickers from a DataFrame. 
+/// Returns a Vec of unique tickers. +pub fn get_unique_tickers(df: &DataFrame) -> Result, Box> { + let ticker_col = get_ticker_column_for_quantamental_dataframe(df)?; + _get_unique_strs_from_str_column_object(&ticker_col) +} + +/// Get the unique cross-sectional identifiers (`cids`) from a DataFrame. +pub fn get_unique_cids(df: &DataFrame) -> Result, Box> { + check_quantamental_dataframe(df)?; + get_unique_from_str_column(df, "cid") +} + +/// Get the unique extended categories (`xcats`) from a DataFrame. +pub fn get_unique_xcats(df: &DataFrame) -> Result, Box> { + check_quantamental_dataframe(df)?; + get_unique_from_str_column(df, "xcat") +} diff --git a/src/utils/qdf/load.rs b/src/utils/qdf/load.rs new file mode 100644 index 0000000..3005367 --- /dev/null +++ b/src/utils/qdf/load.rs @@ -0,0 +1,198 @@ +use crate::utils::misc::{create_ticker, split_ticker}; +use crate::utils::qdf::*; +use anyhow; +use log; +use polars::datatypes::DataType; +use polars::prelude::*; +use rayon::prelude::*; +use std::error::Error; + +/// The number of concurrent file loads to perform. +const CONCURRENT_FILE_LOADS: usize = 500; + +fn _file_base_name(file_path: String) -> String { + std::path::Path::new(&file_path.clone()) + .file_stem() + .unwrap() + .to_str() + .unwrap() + .to_string() +} + +/// Load a Quantamental DataFrame from a CSV file. +/// The CSV must be named in the format `cid_xcat.csv` (`ticker.csv`). +/// The DataFrame must have a `real_date` column along with additional value columns. +pub fn load_quantamental_dataframe( + file_path: &str, +) -> Result> { + // get the file base name + let base_file_name = _file_base_name(file_path.into()); + + // if filename does not have _ then it is not a Quantamental DataFrame + if !base_file_name.contains('_') { + return Err("The file name must be in the format `cid_xcat.csv` (`ticker.csv`)".into()); + } + + let ticker = base_file_name.split('.').collect::>()[0]; + let (cid, xcat) = split_ticker(ticker.to_string())?; + + let mut df = CsvReadOptions::default() + .try_into_reader_with_file_path(Some(file_path.into())) + .unwrap() + .finish() + .unwrap(); + + let err = "The dataframe must have a `real_date` column and atleast 1 additional value column"; + if df.column("real_date").is_err() || df.width() < 2 { + return Err(err.into()); + } + + // check if the first item in the real_date column has a dash or not + let has_dashes = df + .column("real_date") + .unwrap() + .cast(&DataType::String)? + .get(0) + .unwrap() + .to_string() + .contains('-'); + + let date_format = if has_dashes { "%Y-%m-%d" } else { "%Y%m%d" }; + // let real_date_col = df + // .column("real_date".into()) + // .unwrap() + // .cast(&DataType::Date)?; + let real_date_col = df + .column("real_date") + .unwrap() + // .str()? + .cast(&DataType::String)? + .str()? + .as_date(Some(date_format), false) + .map_err(|e| format!("Failed to parse date with format {}: {}", date_format, e))?; + + df.with_column(real_date_col)?; + df.with_column(Series::new("cid".into(), vec![cid; df.height()]))?; + df.with_column(Series::new("xcat".into(), vec![xcat; df.height()]))?; + + sort_qdf_columns(&mut df)?; + + Ok(df) +} + +fn collect_paths_recursively>(path: P) -> std::io::Result> { + let mut paths = Vec::new(); + + for entry in std::fs::read_dir(path)? 
{ + let entry = entry?; + let path = entry.path(); + + if path.is_dir() { + // Recurse into the directory and append the results + paths.extend(collect_paths_recursively(&path)?); + } + // Add the path to the vector + paths.push(path.to_string_lossy().to_string()); + } + + Ok(paths) +} + +fn _load_qdf_thread_safe(file_path: &str) -> Result> { + let res = load_quantamental_dataframe(file_path); + res.map_err(|e| { + anyhow::Error::msg(e.to_string()) + .context("Failed to load quantamental dataframe") + .into() + }) +} +pub fn load_quantamental_dataframe_from_download_bank( + folder_path: &str, + cids: Option>, + xcats: Option>, + tickers: Option>, +) -> Result> { + let rcids = cids.unwrap_or_else(|| Vec::new()); + let rxcats = xcats.unwrap_or_else(|| Vec::new()); + let rtickers = tickers.unwrap_or_else(|| Vec::new()); + + let files = collect_paths_recursively(folder_path)?; + log::info!("Found {} files", files.len()); + + // filter files that are not csv files + let files = files + .iter() + .filter(|f| f.ends_with(".csv")) + .collect::>(); + + log::info!("Found {} csv files", files.len()); + + let mut rel_files = Vec::new(); + for file in files { + let base_file_name: String = _file_base_name(file.into()) + .split('.') + .collect::>()[0] + .into(); + let (cid, xcat) = match split_ticker(base_file_name.clone()) { + Ok((cid, xcat)) => (cid, xcat), + Err(_) => continue, + }; + rel_files.push((file, cid, xcat)); + } + + log::info!("Found {} relevant ticker files", rel_files.len()); + + let load_files = rel_files + .iter() + .filter(|(_, cid, xcat)| { + let f1 = rcids.len() > 0 && rcids.contains(&cid.as_str()); + let f2 = rxcats.len() > 0 && rxcats.contains(&xcat.as_str()); + let f3 = rtickers.len() > 0 && rtickers.contains(&create_ticker(cid, xcat).as_str()); + f1 | f2 | f3 + }) + .map(|(file, _, _)| *file) + .collect::>(); + + // print number of files to load + log::info!("Loading {} files", load_files.len()); + + if load_files.len() == 0 { + return Err("No files to load".into()); + } + if load_files.len() == 1 { + let dfx = load_quantamental_dataframe(load_files[0]).unwrap(); + return Ok(dfx); + } + + let load_files = load_files.iter().map(|s| s.as_str()).collect::>(); + let qdf_batches = load_files + .chunks(CONCURRENT_FILE_LOADS) + .collect::>(); + + let mut results = Vec::new(); + let mut curr_batch = 1; + let total_batches = qdf_batches.len(); + for batch in qdf_batches { + let qdf_list = batch + .par_iter() + .map(|file| _load_qdf_thread_safe(file).unwrap().lazy()) + .collect::>(); + results.extend(qdf_list); + curr_batch += 1; + log::info!("Loaded {}/{} batches", curr_batch, total_batches); + } + + log::info!("Loaded {} files", results.len()); + let res_df = concat(results, UnionArgs::default()) + .unwrap() + .collect() + .unwrap(); + + log::info!( + "Loaded dataframe with {} rows and {} columns", + res_df.height(), + res_df.width() + ); + + Ok(res_df) +} diff --git a/src/utils/qdf/mod.rs b/src/utils/qdf/mod.rs new file mode 100644 index 0000000..cea06c3 --- /dev/null +++ b/src/utils/qdf/mod.rs @@ -0,0 +1,9 @@ +pub mod core; +pub mod update_df; +pub mod load; +pub mod reduce_df; +// Re-export submodules for easier access +pub use core::*; +pub use update_df::*; +pub use load::*; +pub use reduce_df::*; diff --git a/src/utils/qdf/reduce_df.rs b/src/utils/qdf/reduce_df.rs new file mode 100644 index 0000000..ab4d920 --- /dev/null +++ b/src/utils/qdf/reduce_df.rs @@ -0,0 +1,151 @@ +use crate::utils::misc::*; +use crate::utils::qdf::core::*; +use polars::prelude::*; +use std::error::Error; + 
+/// The required columns for a Quantamental DataFrame. +const QDF_INDEX_COLUMNS: [&str; 3] = ["real_date", "cid", "xcat"]; + +/// Filter a dataframe based on the given parameters. +/// - `cids`: Filter by cross-sectional identifiers +/// - `xcats`: Filter by extended categories +/// - `metrics`: Filter by metrics +/// - `start`: Filter by start date +/// - `end`: Filter by end date +/// - `intersect`: If true, intersect only return `cids` that are present for all `xcats`. +/// Returns a new DataFrame with the filtered data, without modifying the original DataFrame. +/// If no filters are provided, the original DataFrame is returned. +pub fn reduce_dataframe( + df: DataFrame, + cids: Option>, + xcats: Option>, + metrics: Option>, + start: Option<&str>, + end: Option<&str>, + intersect: bool, +) -> Result> { + check_quantamental_dataframe(&df)?; + // df_size + let df_size = df.shape(); + let mut new_df = df.clone(); + + let ticker_col: Column = get_ticker_column_for_quantamental_dataframe(&new_df)?; + + // if cids is not provided, get all unique cids + let u_cids: Vec = get_unique_cids(&new_df)?; + let u_xcats: Vec = get_unique_xcats(&new_df)?; + let u_tickers: Vec = _get_unique_strs_from_str_column_object(&ticker_col)?; + + let specified_cids: Vec = cids.unwrap_or_else(|| u_cids.clone()); + let specified_xcats: Vec = xcats.unwrap_or_else(|| u_xcats.clone()); + + let non_idx_cols: Vec = new_df + .get_column_names() + .iter() + .filter(|&col| !QDF_INDEX_COLUMNS.contains(&col.as_str())) + .map(|s| s.to_string()) + .collect(); + + let specified_metrics: Vec = + metrics.unwrap_or_else(|| non_idx_cols.iter().map(|s| s.to_string()).collect()); + + let specified_tickers: Vec = create_interesecting_tickers( + &specified_cids + .iter() + .map(AsRef::as_ref) + .collect::>(), + &specified_xcats + .iter() + .map(AsRef::as_ref) + .collect::>(), + ); + + let keep_tickers: Vec = match intersect { + // true => get_intersecting_cids_str_func(&specified_cids, &specified_xcats, &u_tickers), + true => { + let int_cids = + get_intersecting_cids_str_func(&specified_cids, &specified_xcats, &u_tickers); + create_interesecting_tickers( + &int_cids.iter().map(AsRef::as_ref).collect::>(), + &specified_xcats + .iter() + .map(AsRef::as_ref) + .collect::>(), + ) + } + false => specified_tickers.clone(), + }; + + let kticks: Vec<&str> = keep_tickers + .iter() + .map(AsRef::as_ref) + .collect::>(); + + // Create a boolean mask to filter rows based on the tickers + let mut mask = vec![false; ticker_col.len()]; + for (i, ticker) in ticker_col.str()?.iter().enumerate() { + if let Some(t) = ticker { + if kticks.contains(&t) { + mask[i] = true; + } + } + } + let mask = BooleanChunked::from_slice("mask".into(), &mask); + new_df = new_df.filter(&mask)?; + + // Apply date filtering if `start` or `end` is provided + + if let Some(start) = start { + let start_date = chrono::NaiveDate::parse_from_str(start, "%Y-%m-%d")?; + new_df = new_df + .lazy() + .filter( + col("real_date") + .gt_eq(lit(start_date)) + .alias("real_date") + .into(), + ) + .collect()?; + } + + if let Some(end) = end { + let end_date = chrono::NaiveDate::parse_from_str(end, "%Y-%m-%d")?; + new_df = new_df + .lazy() + .filter( + col("real_date") + .lt_eq(lit(end_date)) + .alias("real_date") + .into(), + ) + .collect()?; + } + + // Filter based on metrics if provided + assert!(specified_metrics.len() > 0); + + // remove columns that are not in the specified metrics + let mut cols_to_remove = Vec::new(); + for col in new_df.get_column_names() { + if 
!specified_metrics.contains(&col.to_string()) + && !QDF_INDEX_COLUMNS.contains(&col.as_str()) + { + cols_to_remove.push(col); + } + } + new_df = new_df.drop_many( + cols_to_remove + .iter() + .map(|s| s.to_string()) + .collect::>(), + ); + // check if the df is still the same size + let new_df_size = new_df.shape(); + if df_size != new_df_size { + println!( + "Reduced DataFrame from {} to {} rows", + df_size.0, new_df_size.0 + ); + } + Ok(new_df) +} diff --git a/src/utils/qdf/update_df.rs b/src/utils/qdf/update_df.rs new file mode 100644 index 0000000..bf24965 --- /dev/null +++ b/src/utils/qdf/update_df.rs @@ -0,0 +1,35 @@ +use crate::utils::qdf::core::*; +use polars::prelude::*; +use std::error::Error; + +const QDF_INDEX_COLUMNS: [&str; 3] = ["real_date", "cid", "xcat"]; + +/// Update a Quantamental DataFrame with new data. +/// - `df`: The original DataFrame +/// - `df_add`: The new DataFrame to add +/// +pub fn update_dataframe( + df: &DataFrame, + df_add: &DataFrame, + // xcat_replace: Option<&str>, +) -> Result> { + check_quantamental_dataframe(df)?; + check_quantamental_dataframe(df_add)?; + if df.is_empty() { + return Ok(df_add.clone()); + } else if df_add.is_empty() { + return Ok(df.clone()); + }; + + // vstack and drop duplicates keeping last + let mut new_df = df.vstack(df_add)?; + // help? + let idx_cols_vec = QDF_INDEX_COLUMNS + .iter() + .map(|s| s.to_string()) + .collect::>(); + + new_df = new_df.unique_stable(Some(&idx_cols_vec), UniqueKeepStrategy::Last, None)?; + + Ok(new_df) +}
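
---

For reviewers: below is a minimal usage sketch of the new `utils::qdf` re-exports introduced by this patch, mirroring the calls exercised in `src/main.rs` above (load from a download bank, reduce, update). It assumes the crate builds with this change applied and that a local download bank of `cid_xcat.csv` files exists; the path and ticker names are illustrative only, not part of the patch.

```rust
use msyrs::utils::qdf::*;

fn qdf_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical path to a local JPMaQS download bank of `cid_xcat.csv` files.
    let bank_path = "data/JPMaQSData";

    // Load only the requested tickers (cid/xcat filters left as None).
    let df = load_quantamental_dataframe_from_download_bank(
        bank_path,
        None,
        None,
        Some(vec!["USD_EQXR_NSA", "GBP_EQXR_NSA"]),
    )?;
    assert!(is_quantamental_dataframe(&df));

    // Keep only GBP EQXR_NSA rows from 2020-01-01 onwards; `false` disables
    // the cid/xcat intersection behaviour.
    let gbp = reduce_dataframe(
        df.clone(),
        Some(vec!["GBP".to_string()]),
        Some(vec!["EQXR_NSA".to_string()]),
        None,
        Some("2020-01-01"),
        None,
        false,
    )?;

    // Upsert the reduced slice back into the full frame; duplicate
    // (real_date, cid, xcat) rows keep the values from `gbp`.
    let merged = update_dataframe(&df, &gbp)?;
    println!("{:?}", merged.head(Some(5)));
    Ok(())
}
```

Design note: `reduce_dataframe` now takes the `DataFrame` by value (callers clone explicitly, as `main.rs` does), while `update_dataframe` borrows both frames and deduplicates on the `real_date`/`cid`/`xcat` index with `UniqueKeepStrategy::Last`, so rows from the second argument win on conflict.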