diff --git a/Cargo.lock b/Cargo.lock index 5f9a623..b82ff1e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1233,6 +1233,7 @@ dependencies = [ "log", "polars", "rand", + "rayon", "reqwest", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 9b008c6..151c4ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,12 +8,9 @@ reqwest = { version = "0.12.9", features = ["blocking", "json"] } serde_json = "1.0" serde_urlencoded = "0.7" serde = { version = "1.0", features = ["derive"] } -# polars = { version = "0.44.2", features = ["temporal", "lazy"] } polars = { version = "0.44.2", features = ["lazy"] } - -# anyhow = "1.0.92" rand = "0.8" threadpool = "1.8.1" log = "0.4.22" -# dotenv = "0.15.0" crossbeam = "0.8" +rayon = "1.5" diff --git a/src/download/jpmaqsdownload.rs b/src/download/jpmaqsdownload.rs index 239202a..77d6bff 100644 --- a/src/download/jpmaqsdownload.rs +++ b/src/download/jpmaqsdownload.rs @@ -3,6 +3,7 @@ use crate::download::requester::DQRequester; use crate::download::requester::DQTimeseriesRequestArgs; use crate::download::timeseries::DQTimeSeriesResponse; use crate::download::timeseries::JPMaQSIndicator; +use rayon::prelude::*; use std::error::Error; const DEFAULT_JPMAQS_METRICS: [&str; 4] = ["value", "grading", "eop_lag", "mop_lag"]; @@ -111,13 +112,31 @@ impl JPMaQSDownload { assert!(all_jpmaq_expressions(expressions.clone())); let dqts_vec = self.get_expressions(expressions)?; - println!("Retrieved {} time series", dqts_vec.len()); + // println!("Retrieved {} time series", -- sum[dqts_vec.iter().map(|dqts| dqts.len())]); + println!( + "Retrieved {} time series", + dqts_vec + .iter() + .map(|dqts| dqts.list_expressions().len()) + .sum::() + ); + let start = std::time::Instant::now(); - let indicators = dqts_vec - .iter() - .flat_map(|dqts| dqts.get_timeseries_by_ticker()) - .map(|tsv| JPMaQSIndicator::new(tsv)) - .collect::, Box>>()?; + + // let indicators = dqts_vec + // .iter() + // .flat_map(|dqts| dqts.get_timeseries_by_ticker()) + // .map(|tsv| JPMaQSIndicator::new(tsv)) + // .collect::, Box>>()?; + let indicators: Vec<_> = dqts_vec + .par_iter() + .flat_map(|dqts| { + dqts.get_timeseries_by_ticker() + .into_par_iter() + .filter_map(|tsv| JPMaQSIndicator::new(tsv).ok()) + }) + .collect(); + println!( "Converted time series to indicators in {:?}", start.elapsed() diff --git a/src/download/requester.rs b/src/download/requester.rs index a714004..9dbfe62 100644 --- a/src/download/requester.rs +++ b/src/download/requester.rs @@ -2,19 +2,23 @@ use crate::download::oauth_client::OAuthClient; use crate::download::timeseries::DQCatalogueResponse; use crate::download::timeseries::DQCatalogueSingleResponse; use crate::download::timeseries::DQTimeSeriesResponse; +use crossbeam::channel; use reqwest; use reqwest::blocking::Client; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; +// use std::collections::HashMap; use std::error::Error; use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; -use crossbeam::channel; const API_BASE_URL: &str = "https://api-developer.jpmorgan.com/research/dataquery-authe/api/v2"; const HEARTBEAT_ENDPOINT: &str = "/services/heartbeat"; const TIMESERIES_ENDPOINT: &str = "/expressions/time-series"; const CATALOGUE_ENDPOINT: &str = "/group/instruments"; +const API_DELAY_MILLIS: u64 = 200; + +const MAX_THREAD_WORKERS: usize = 20; // const JPMAQS_CATALOGUE_GROUP: &str = "JPMAQS"; #[derive(Debug, Clone)] @@ -162,7 +166,7 @@ impl DQRequester { )); while let Some(endpoint) = next_page { - std::thread::sleep(std::time::Duration::from_millis(200)); + std::thread::sleep(std::time::Duration::from_millis(API_DELAY_MILLIS)); let response = self._request(reqwest::Method::GET, &endpoint)?; if !response.status().is_success() { @@ -244,41 +248,64 @@ impl DQRequester { args: DQTimeseriesRequestArgs, max_retries: u32, ) -> Result, Box> { + // sort to ensure that the order of expressions is consistent let expression_batches: Vec> = args .expressions .chunks(20) - .map(|chunk| chunk.to_vec()) + .map(|chunk| { + let mut vec = chunk.to_vec(); + vec.sort(); + vec + }) .collect(); let okay_responses = Arc::new(Mutex::new(Vec::new())); let failed_batches = Arc::new(Mutex::new(Vec::new())); let client = Arc::new(Mutex::new(self.clone())); - let mut handles = vec![]; + let (sender, receiver) = channel::unbounded(); let total_batches = expression_batches.len(); let mut curr_batch = 0; + // Spawn 20 worker threads + let mut workers = vec![]; + for _ in 0..MAX_THREAD_WORKERS { + let receiver = receiver.clone(); + let okay_responses = Arc::clone(&okay_responses); + let failed_batches = Arc::clone(&failed_batches); + let client = Arc::clone(&client); + + let worker = thread::spawn(move || { + while let Ok((args, expr_batch)) = receiver.recv() { + DQRequester::_fetch_expression_batch( + client.clone(), + expr_batch, + okay_responses.clone(), + failed_batches.clone(), + args, + ); + } + }); + workers.push(worker); + } + + // Send jobs to workers for expr_batch in expression_batches { curr_batch += 1; let mut args = args.clone(); args.update_expressions(expr_batch.clone()); - let okay_responses = Arc::clone(&okay_responses); - let failed_batches = Arc::clone(&failed_batches); - let client = Arc::clone(&client); - log::info!("Processed {} out of {} batches", curr_batch, total_batches); - thread::sleep(Duration::from_millis(200)); + thread::sleep(Duration::from_millis(API_DELAY_MILLIS)); - let handle = thread::spawn(move || { - DQRequester::_fetch_expression_batch(client, expr_batch, okay_responses, failed_batches, args); - }); - - handles.push(handle); + sender.send((args, expr_batch)).unwrap(); } - for handle in handles { - handle.join().unwrap(); + drop(sender); // Close the channel so workers can finish + + // Wait for all workers to finish + for worker in workers { + worker.join().unwrap(); } let mut okay_responses = Arc::try_unwrap(okay_responses) diff --git a/src/main.rs b/src/main.rs index 88596b2..f86a4cd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ fn main() { start.elapsed() ); - let num_ticks = 250; + let num_ticks = 1000; let sel_tickers: Vec = tickers .iter() .take(num_ticks)