diff --git a/Cargo.lock b/Cargo.lock index f934dbd..5f9a623 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -425,6 +425,19 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -1216,6 +1229,7 @@ dependencies = [ name = "msyrs" version = "0.0.1" dependencies = [ + "crossbeam", "log", "polars", "rand", diff --git a/Cargo.toml b/Cargo.toml index 014100d..9b008c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,4 @@ rand = "0.8" threadpool = "1.8.1" log = "0.4.22" # dotenv = "0.15.0" +crossbeam = "0.8" diff --git a/src/download/jpmaqsdownload.rs b/src/download/jpmaqsdownload.rs index 1cc1ebd..239202a 100644 --- a/src/download/jpmaqsdownload.rs +++ b/src/download/jpmaqsdownload.rs @@ -7,18 +7,20 @@ use std::error::Error; const DEFAULT_JPMAQS_METRICS: [&str; 4] = ["value", "grading", "eop_lag", "mop_lag"]; -fn ticker_to_expressions(ticker: &str) -> Vec { - DEFAULT_JPMAQS_METRICS +fn ticker_to_expressions(ticker: &str, metrics: Vec<&str>) -> Vec { + metrics .iter() .map(|metric| format!("DB(JPMAQS,{},{})", ticker, metric)) .collect::>() } -fn construct_expressions(tickers: Vec) -> Vec { +fn construct_expressions(tickers: Vec, metrics: Vec) -> Vec { tickers .iter() - .flat_map(|ticker| ticker_to_expressions(ticker)) - .collect() + .flat_map(|ticker| { + ticker_to_expressions(ticker, metrics.clone().iter().map(|s| s.as_str()).collect()) + }) + .collect::>() } fn is_jpmaq_expression(expression: &str) -> bool { @@ -35,6 +37,28 @@ fn all_jpmaq_expressions(expressions: Vec) -> bool { .all(|expression| is_jpmaq_expression(expression)) } +#[derive(Debug, Clone)] +pub struct JPMaQSDownloadGetIndicatorArgs { + pub tickers: Vec, + pub metrics: Vec, + pub start_date: String, + pub end_date: String, +} + +impl Default for JPMaQSDownloadGetIndicatorArgs { + fn default() -> Self { + JPMaQSDownloadGetIndicatorArgs { + tickers: Vec::new(), + metrics: DEFAULT_JPMAQS_METRICS + .iter() + .map(|s| s.to_string()) + .collect(), + start_date: "1990-01-01".to_string(), + end_date: "TODAY+2D".to_string(), + } + } +} + #[derive(Debug, Clone)] pub struct JPMaQSDownload { requester: DQRequester, @@ -77,9 +101,13 @@ impl JPMaQSDownload { pub fn get_indicators( &mut self, - tickers: Vec, + download_args: JPMaQSDownloadGetIndicatorArgs, ) -> Result, Box> { - let expressions = construct_expressions(tickers); + if download_args.tickers.is_empty() { + return Err("No tickers provided".into()); + } + + let expressions = construct_expressions(download_args.tickers, download_args.metrics); assert!(all_jpmaq_expressions(expressions.clone())); let dqts_vec = self.get_expressions(expressions)?; diff --git a/src/download/requester.rs b/src/download/requester.rs index d871a39..a714004 100644 --- a/src/download/requester.rs +++ b/src/download/requester.rs @@ -9,6 +9,7 @@ use std::error::Error; use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; +use crossbeam::channel; const API_BASE_URL: &str = "https://api-developer.jpmorgan.com/research/dataquery-authe/api/v2"; const HEARTBEAT_ENDPOINT: &str = "/services/heartbeat"; @@ -206,6 +207,38 @@ impl DQRequester { Ok(response) } + fn _fetch_expression_batch( + client: Arc>, + expr_batch: Vec, + okay_responses: Arc>>, + failed_batches: Arc>>>, + args: DQTimeseriesRequestArgs, + ) { + let response = client.lock().unwrap()._fetch_single_timeseries_batch(args); + + match response { + Ok(r) => { + // Attempt to parse the response text + match serde_json::from_str::(&r.text().unwrap()) { + Ok(dq_response) => { + okay_responses.lock().unwrap().push(dq_response); + log::info!("Got batch: {:?}", expr_batch); + } + Err(e) => { + // If parsing fails, treat this as a batch failure + failed_batches.lock().unwrap().push(expr_batch.clone()); + log::error!("Failed to parse timeseries: {:?} : {:?}", expr_batch, e); + } + } + } + Err(e) => { + // Handle _fetch_single_timeseries_batch error + failed_batches.lock().unwrap().push(expr_batch.clone()); + log::error!("Failed to get batch: {:?} : {:?}", expr_batch, e); + } + } + } + fn _fetch_timeseries_recursive( &mut self, args: DQTimeseriesRequestArgs, @@ -219,50 +252,26 @@ impl DQRequester { let okay_responses = Arc::new(Mutex::new(Vec::new())); let failed_batches = Arc::new(Mutex::new(Vec::new())); - let client = Arc::new(Mutex::new(self.clone())); let mut handles = vec![]; let total_batches = expression_batches.len(); let mut curr_batch = 0; + for expr_batch in expression_batches { curr_batch += 1; let mut args = args.clone(); args.update_expressions(expr_batch.clone()); + let okay_responses = Arc::clone(&okay_responses); let failed_batches = Arc::clone(&failed_batches); let client = Arc::clone(&client); - // if curr_batch mod 100 == 0 print progress + log::info!("Processed {} out of {} batches", curr_batch, total_batches); thread::sleep(Duration::from_millis(200)); - let handle = thread::spawn(move || { - let response = client.lock().unwrap()._fetch_single_timeseries_batch(args); - match response { - Ok(r) => { - // Attempt to parse the response text - match serde_json::from_str::(&r.text().unwrap()) { - Ok(dq_response) => { - okay_responses.lock().unwrap().push(dq_response); - log::info!("Got batch: {:?}", expr_batch); - } - Err(e) => { - // If parsing fails, treat this as a batch failure - failed_batches.lock().unwrap().push(expr_batch.clone()); - log::error!( - "Failed to parse timeseries: {:?} : {:?}", - expr_batch, - e - ); - } - } - } - Err(e) => { - // Handle _fetch_single_timeseries_batch error - failed_batches.lock().unwrap().push(expr_batch.clone()); - log::error!("Failed to get batch: {:?} : {:?}", expr_batch, e); - } - } + let handle = thread::spawn(move || { + DQRequester::_fetch_expression_batch(client, expr_batch, okay_responses, failed_batches, args); }); handles.push(handle); diff --git a/src/main.rs b/src/main.rs index fe3df3e..88596b2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use msyrs::download::jpmaqsdownload::JPMaQSDownload; +use msyrs::download::jpmaqsdownload::{JPMaQSDownload, JPMaQSDownloadGetIndicatorArgs}; fn main() { println!("Authentication to DataQuery API"); @@ -16,7 +16,7 @@ fn main() { start.elapsed() ); - let num_ticks = 1000; + let num_ticks = 250; let sel_tickers: Vec = tickers .iter() .take(num_ticks) @@ -26,7 +26,12 @@ fn main() { println!("Retrieving indicators for {} tickers", sel_tickers.len()); start = std::time::Instant::now(); - let indicators = jpamqs_download.get_indicators(sel_tickers.clone()).unwrap(); + let indicators = jpamqs_download + .get_indicators(JPMaQSDownloadGetIndicatorArgs { + tickers: sel_tickers.clone(), + ..Default::default() + }) + .unwrap(); println!( "Retrieved indicators for {} tickers in {:?}",