diff --git a/Cargo.lock b/Cargo.lock index 26caa58..d60b68c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -75,6 +75,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anyhow" +version = "1.0.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" + [[package]] name = "argminmax" version = "0.6.2" @@ -1398,6 +1404,7 @@ dependencies = [ name = "msyrs" version = "0.0.1" dependencies = [ + "anyhow", "chrono", "crossbeam", "futures", diff --git a/Cargo.toml b/Cargo.toml index 2e8ed48..e48f8b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ name = "msyrs" path = "src/lib.rs" [dependencies] +anyhow = "1.0" reqwest = { version = "0.12.9", features = ["blocking", "json"] } serde_json = "1.0" serde_urlencoded = "0.7" diff --git a/src/main.rs b/src/main.rs index f6510e3..1e9dd0d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -51,7 +51,21 @@ fn download_stuff() { fn main() { // E:\Work\ruzt\msyrs\data\JPMaQSData\ALLIFCDSGDP\AUD_ALLIFCDSGDP_NSA.csv - let pth = "E:/Work/ruzt/msyrs/data/JPMaQSData/ALLIFCDSGDP/AUD_ALLIFCDSGDP_NSA.csv"; - let df = msyrs_dftools::load_quantamental_dataframe(pth).unwrap(); + // let pth = "E:/Work/ruzt/msyrs/data/JPMaQSData/ALLIFCDSGDP/AUD_ALLIFCDSGDP_NSA.csv"; + // let df = msyrs_dftools::load_quantamental_dataframe(pth).unwrap(); + // println!("{:?}", df); + // load_quantamental_dataframe_from_download_bank + // let st_pth = "E:/Work/ruzt/msyrs/data/JPMaQSData/"; + let st_pth = "E:/Work/jpmaqs-isc-git/jpmaqs-iscs"; + + let df = msyrs_dftools::load_quantamental_dataframe_from_download_bank( + st_pth, + Some(vec!["AUD", "USD", "GBP", "JPY", "EUR", "CAD", "CHF", "INR", "CNY"]), + Some(vec!["EQXR_NSA", "FXXR_NSA", "RIR_NSA", "ALLIFCDSGDP_NSA"]), + None, + ) + .unwrap(); println!("{:?}", df); + + // download_stuff(); } diff --git a/src/utils/dftools.rs b/src/utils/dftools.rs index c7ca3f5..1055091 100644 --- a/src/utils/dftools.rs +++ b/src/utils/dftools.rs @@ -1,7 +1,10 @@ use crate::utils::misc::*; +use anyhow; use polars::datatypes::DataType; use polars::prelude::*; +use rayon::prelude::*; use std::error::Error; +use std::fs; /// The standard metrics provided by JPMaQS (`value`, `grading`, `eop_lag`, `mop_lag`). pub const DEFAULT_JPMAQS_METRICS: [&str; 4] = ["value", "grading", "eop_lag", "mop_lag"]; @@ -74,24 +77,31 @@ pub fn sort_qdf_columns(qdf: &mut DataFrame) -> Result<(), Box> { Ok(()) } -pub fn load_quantamental_dataframe( - file_path: &str, -) -> Result> { - // get the file base name - let file_name = std::path::Path::new(file_path) +fn _file_base_name(file_path: String) -> String { + std::path::Path::new(&file_path.clone()) .file_stem() .unwrap() .to_str() .unwrap() - .to_string(); + .to_string() +} + +/// Load a Quantamental DataFrame from a CSV file. +/// The CSV must be named in the format `cid_xcat.csv` (`ticker.csv`). +/// The DataFrame must have a `real_date` column along with additional value columns. +pub fn load_quantamental_dataframe( + file_path: &str, +) -> Result> { + // get the file base name + let base_file_name = _file_base_name(file_path.into()); // if filename does not have _ then it is not a Quantamental DataFrame - if !file_name.contains('_') { + if !base_file_name.contains('_') { return Err("The file name must be in the format `cid_xcat.csv` (`ticker.csv`)".into()); } - let ticker = file_name.split('.').collect::>()[0]; - let (cid, xcat) = split_ticker(ticker)?; + let ticker = base_file_name.split('.').collect::>()[0]; + let (cid, xcat) = split_ticker(ticker.to_string())?; let mut df = CsvReadOptions::default() .try_into_reader_with_file_path(Some(file_path.into())) @@ -103,10 +113,30 @@ pub fn load_quantamental_dataframe( if df.column("real_date").is_err() || df.width() < 2 { return Err(err.into()); } - let real_date_col = df - .column("real_date".into()) + + // check if the first item in the real_date column has a dash or not + let has_dashes = df + .column("real_date") .unwrap() - .cast(&DataType::Date)?; + .cast(&DataType::String)? + .get(0) + .unwrap() + .to_string() + .contains('-'); + + let date_format = if has_dashes { "%Y-%m-%d" } else { "%Y%m%d" }; + // let real_date_col = df + // .column("real_date".into()) + // .unwrap() + // .cast(&DataType::Date)?; + let real_date_col = df + .column("real_date") + .unwrap() + // .str()? + .cast(&DataType::String)? + .str()? + .as_date(Some(date_format), false) + .map_err(|e| format!("Failed to parse date with format {}: {}", date_format, e))?; df.with_column(real_date_col)?; df.with_column(Series::new("cid".into(), vec![cid; df.height()]))?; @@ -117,6 +147,101 @@ pub fn load_quantamental_dataframe( Ok(df) } +fn _load_qdf_thread_safe(file_path: &str) -> Result> { + let res = load_quantamental_dataframe(file_path); + res.map_err(|e| { + anyhow::Error::msg(e.to_string()) + .context("Failed to load quantamental dataframe") + .into() + }) +} +pub fn load_quantamental_dataframe_from_download_bank( + folder_path: &str, + cids: Option>, + xcats: Option>, + tickers: Option>, +) -> Result> { + let rcids = cids.unwrap_or_else(|| Vec::new()); + let rxcats = xcats.unwrap_or_else(|| Vec::new()); + let rtickers = tickers.unwrap_or_else(|| Vec::new()); + + // recursively read list of all files in the folder as a vector of strings + let files = fs::read_dir(folder_path)? + .map(|res| res.map(|e| e.path().display().to_string())) + .collect::, std::io::Error>>()?; + + // print number of files found + + // filter files that are not csv files + let files = files + .iter() + .filter(|f| f.ends_with(".csv")) + .collect::>(); + + // print number of csv files found + + let mut rel_files = Vec::new(); + for file in files { + let base_file_name: String = _file_base_name(file.into()) + .split('.') + .collect::>()[0] + .into(); + let (cid, xcat) = match split_ticker(base_file_name.clone()) { + Ok((cid, xcat)) => (cid, xcat), + Err(_) => continue, + }; + rel_files.push((file, cid, xcat)); + } + + // print number of relevant ticker files found + + let load_files = rel_files + .iter() + .filter(|(_, cid, xcat)| { + let f1 = rcids.len() > 0 && rcids.contains(&cid.as_str()); + let f2 = rxcats.len() > 0 && rxcats.contains(&xcat.as_str()); + let f3 = rtickers.len() > 0 && rtickers.contains(&create_ticker(cid, xcat).as_str()); + f1 | f2 | f3 + }) + .map(|(file, _, _)| *file) + .collect::>(); + + // print number of files to load + println!("Loading {} files", load_files.len()); + + if load_files.len() == 0 { + return Err("No files to load".into()); + } + if load_files.len() == 1 { + let dfx = load_quantamental_dataframe(load_files[0]).unwrap(); + return Ok(dfx); + } + + let load_files = load_files.iter().map(|s| s.as_str()).collect::>(); + let qdf_batches = load_files.chunks(500).collect::>(); + + let mut results = Vec::new(); + let mut curr_batch = 1; + let total_batches = qdf_batches.len(); + for batch in qdf_batches { + let qdf_list = batch + .par_iter() + .map(|file| _load_qdf_thread_safe(file).unwrap()) + .collect::>(); + results.extend(qdf_list); + curr_batch += 1; + } + + println!("Loaded {} files", results.len()); + + let mut res_df: DataFrame = results.pop().unwrap(); + while let Some(df) = results.pop() { + res_df = res_df.vstack(&df).unwrap(); + } + + Ok(res_df) +} + /// Get intersecting cross-sections from a DataFrame. pub fn get_intersecting_cids( df: &DataFrame, diff --git a/src/utils/misc.rs b/src/utils/misc.rs index d1491a3..3e27133 100644 --- a/src/utils/misc.rs +++ b/src/utils/misc.rs @@ -2,22 +2,22 @@ use polars::prelude::*; use std::collections::HashMap; use std::error::Error; -pub fn split_ticker(ticker: &str) -> Result<(&str, &str), Box> { +pub fn split_ticker(ticker: String) -> Result<(String, String), Box> { // split by the first underscore character. return the first and second parts. let parts: Vec<&str> = ticker.splitn(2, '_').collect(); if parts.len() != 2 { return Err("Invalid ticker format".into()); } - Ok((parts[0], parts[1])) + Ok((parts[0].to_string(), parts[1].to_string())) } #[allow(dead_code)] -pub fn get_cid(ticker: &str) -> Result<&str, Box> { +pub fn get_cid(ticker: String) -> Result> { split_ticker(ticker).map(|(cid, _)| cid) } #[allow(dead_code)] -pub fn get_xcat(ticker: &str) -> Result<&str, Box> { +pub fn get_xcat(ticker: String) -> Result> { split_ticker(ticker).map(|(_, xcat)| xcat) } @@ -67,7 +67,7 @@ pub fn get_intersecting_cids_str_func( // make a hashmap of cids to xcats let mut cid_xcat_map = HashMap::new(); for ticker in found_tickers { - let (cid, xcat) = split_ticker(&ticker).unwrap(); + let (cid, xcat) = split_ticker(ticker.clone()).unwrap(); cid_xcat_map.insert(cid.to_string(), xcat.to_string()); }