diff --git a/src/download/jpmaqsdownload.rs b/src/download/jpmaqsdownload.rs index ced37eb..9d11fda 100644 --- a/src/download/jpmaqsdownload.rs +++ b/src/download/jpmaqsdownload.rs @@ -1,7 +1,8 @@ use crate::download::oauth_client::OAuthClient; use crate::download::requester::DQRequester; -use crate::download::timeseries::DQTimeseriesRequestArgs; +use crate::download::timeseries::DQTimeSeries; use crate::download::timeseries::DQTimeSeriesResponse; +use crate::download::timeseries::DQTimeseriesRequestArgs; use crate::download::timeseries::JPMaQSIndicator; use rayon::prelude::*; use std::error::Error; @@ -92,10 +93,13 @@ impl JPMaQSDownload { &mut self, expressions: Vec, ) -> Result, Box> { - let dqts_vec = self.requester.get_timeseries(DQTimeseriesRequestArgs { - expressions: expressions, - ..Default::default() - }).unwrap(); + let dqts_vec = self + .requester + .get_timeseries(DQTimeseriesRequestArgs { + expressions: expressions, + ..Default::default() + }) + .unwrap(); Ok(dqts_vec) } @@ -113,35 +117,50 @@ impl JPMaQSDownload { let dqts_vec = self.get_expressions(expressions)?; // println!("Retrieved {} time series", -- sum[dqts_vec.iter().map(|dqts| dqts.len())]); - println!( - "Retrieved {} time series", - dqts_vec - .iter() - .map(|dqts| dqts.list_expressions().len()) - .sum::() - ); - + println!("Retrieved all time series",); + // say pausing for 30 seconds + println!("Pausing for 10 seconds"); + std::thread::sleep(std::time::Duration::from_secs(10)); + println!("Resuming"); let start = std::time::Instant::now(); - // let indicators = dqts_vec - // .iter() - // .flat_map(|dqts| dqts.get_timeseries_by_ticker()) - // .map(|tsv| JPMaQSIndicator::new(tsv)) - // .collect::, Box>>()?; - let indicators: Vec<_> = dqts_vec - .into_par_iter() // Use into_par_iter() to parallelize and consume dqts_vec - .flat_map(|dqts| { - dqts.get_timeseries_by_ticker() - .into_par_iter() - .filter_map(|tsv| JPMaQSIndicator::new(tsv).ok()) - }) - .collect(); + let indicators = dq_response_vec_to_jpmaqs_indicators(dqts_vec); println!( "Converted time series to indicators in {:?}", start.elapsed() ); - + println!("Pausing for 10 seconds"); + std::thread::sleep(std::time::Duration::from_secs(10)); + println!("Resuming"); Ok(indicators) } } + +// fn dq_response_vec_to_jpmaqs_indicators( +// dqts_vec: Vec, +// ) -> Vec { +// let mut indicators: Vec = Vec::new(); +// for dqts in dqts_vec { +// indicators.extend( +// dqts.consume_to_grouped_by_ticker() // moves the values to free up memory +// .into_iter() +// .filter_map(|tsv| JPMaQSIndicator::new(tsv).ok()), +// ); +// } + +// indicators +// } +fn dq_response_vec_to_jpmaqs_indicators( + dqts_vec: Vec, +) -> Vec { + dqts_vec + .into_par_iter() + .flat_map(|dqts| { + dqts.consume_to_grouped_by_ticker() + .into_iter() + .filter_map(|tsv| JPMaQSIndicator::new(tsv).ok()) + .collect::>() + }) + .collect() +} diff --git a/src/download/parreq.rs b/src/download/parreq.rs index 8e9d1ef..0251a53 100644 --- a/src/download/parreq.rs +++ b/src/download/parreq.rs @@ -79,7 +79,7 @@ impl ParallelRequester { runtime.block_on(async { let mut tasks = vec![]; - let mut curr_batch = 0; + let mut curr_batch = 1; let total_batches = expression_batches.len(); for batch in expression_batches.into_iter() { let self_arc = Arc::clone(&self_arc); @@ -89,7 +89,10 @@ impl ParallelRequester { args_clone.update_expressions(batch.clone()); let ep = format!("{}?{}", TIMESERIES_ENDPOINT, args_clone.as_query_string()); let permit = semaphore.clone().acquire_owned().await.unwrap(); - println!("Requesting batch {} of {}", curr_batch + 1, total_batches); + // println!("Requesting batch {} of {}", curr_batch + 1, total_batches); + if curr_batch % 10 == 0 { + println!("Requesting batch {} of {}", curr_batch, total_batches); + } let task = task::spawn(async move { let _permit = permit; // Keep the permit alive until the end of the task let res_str = self_arc._request_async(&ep).await; diff --git a/src/download/requester.rs b/src/download/requester.rs index ece0b21..ee9d26b 100644 --- a/src/download/requester.rs +++ b/src/download/requester.rs @@ -4,21 +4,12 @@ use crate::download::timeseries::DQCatalogueResponse; use crate::download::timeseries::DQCatalogueSingleResponse; use crate::download::timeseries::DQTimeSeriesResponse; use crate::download::timeseries::DQTimeseriesRequestArgs; -use crossbeam::channel; -use rayon::iter::IntoParallelRefMutIterator; -use rayon::iter::ParallelDrainRange; -use rayon::iter::{ParallelIterator, IntoParallelIterator}; - -use rayon::prelude::*; -// use futures::TryFutureExt; +use crate::download::timeseries::JPMaQSIndicator; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use reqwest; use reqwest::blocking::Client; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; -// use std::collections::HashMap; use std::error::Error; -use std::sync::{Arc, Mutex}; -use std::thread; -use std::time::Duration; const API_BASE_URL: &str = "https://api-developer.jpmorgan.com/research/dataquery-authe/api/v2"; const HEARTBEAT_ENDPOINT: &str = "/services/heartbeat"; @@ -26,7 +17,7 @@ const TIMESERIES_ENDPOINT: &str = "/expressions/time-series"; const CATALOGUE_ENDPOINT: &str = "/group/instruments"; const API_DELAY_MILLIS: u64 = 200; -const MAX_THREAD_WORKERS: usize = 100; +// const MAX_THREAD_WORKERS: usize = 100; // const JPMAQS_CATALOGUE_GROUP: &str = "JPMAQS"; #[derive(Debug, Clone)] @@ -161,159 +152,42 @@ impl DQRequester { args.expressions.len() ); let mut pq = ParallelRequester::new(self.oauth_client.clone()); - let pq_output = pq.request_expressions(args, max_retries); + let start = std::time::Instant::now(); - let response_texts = match pq_output { + let response_texts = match pq.request_expressions(args, max_retries) { Ok(r) => r, Err(e) => return Err(e), }; + println!( + "Time elapsed for pq.request_expressions: {:?}", + start.elapsed() + ); - // Parse the responses using rayon's parallel iterator - let dqts_vec: Vec = response_texts - .par_iter() - .filter_map(|rtext| serde_json::from_str::(rtext).ok()) - .collect(); + let dqts_vec: Vec = parse_response_texts(response_texts); + // Sleep for 10 seconds + println!("Pausing for 10 seconds"); + std::thread::sleep(std::time::Duration::from_secs(10)); + println!("Resuming"); Ok(dqts_vec) } } -// pub fn get_timeseries( -// &mut self, -// args: DQTimeseriesRequestArgs, -// ) -> Result, Box> { -// let max_retries = 5; -// println!( -// "Invoking recursive function for {:?} expressions", -// args.expressions.len() -// ); -// _fetch_timeseries_recursive(self, args, max_retries) -// } -// } -fn _fetch_expression_batch( - client: Arc>, - expr_batch: Vec, - okay_responses: Arc>>, - failed_batches: Arc>>>, - args: DQTimeseriesRequestArgs, -) { - let response = client.lock().unwrap()._fetch_single_timeseries_batch(args); - - match response { - Ok(r) => { - // Attempt to parse the response text - match serde_json::from_str::(&r.text().unwrap()) { - Ok(dq_response) => { - okay_responses.lock().unwrap().push(dq_response); - log::info!("Got batch: {:?}", expr_batch); - } - Err(e) => { - // If parsing fails, treat this as a batch failure - failed_batches.lock().unwrap().push(expr_batch.clone()); - log::error!("Failed to parse timeseries: {:?} : {:?}", expr_batch, e); +fn parse_response_texts(response_texts: Vec) -> Vec { + response_texts + .into_par_iter() + .filter_map(|rtext| { + // Attempt to deserialize and immediately release rtext if successful + match serde_json::from_str::(&rtext) { + Ok(dqts) => Some(dqts), + Err(err) => { + eprintln!("Failed to deserialize response: {}", err); + None } } - } - Err(e) => { - // Handle _fetch_single_timeseries_batch error - failed_batches.lock().unwrap().push(expr_batch.clone()); - log::error!("Failed to get batch: {:?} : {:?}", expr_batch, e); - } - } -} - -fn _fetch_timeseries_recursive( - dq_requester: &mut DQRequester, - args: DQTimeseriesRequestArgs, - max_retries: u32, -) -> Result, Box> { - // sort to ensure that the order of expressions is consistent - let expression_batches: Vec> = args - .expressions - .chunks(20) - .map(|chunk| { - let mut vec = chunk.to_vec(); - vec.sort(); - vec }) - .collect(); - - let okay_responses = Arc::new(Mutex::new(Vec::new())); - let failed_batches = Arc::new(Mutex::new(Vec::new())); - let client = Arc::new(Mutex::new(dq_requester.clone())); - - let (sender, receiver) = channel::unbounded(); - let total_batches = expression_batches.len(); - let mut curr_batch = 0; - - // Spawn 20 worker threads - let mut workers = vec![]; - for _ in 0..MAX_THREAD_WORKERS { - let receiver = receiver.clone(); - let okay_responses = Arc::clone(&okay_responses); - let failed_batches = Arc::clone(&failed_batches); - let client = Arc::clone(&client); - - let worker = thread::spawn(move || { - while let Ok((args, expr_batch)) = receiver.recv() { - _fetch_expression_batch( - client.clone(), - expr_batch, - okay_responses.clone(), - failed_batches.clone(), - args, - ); - } - }); - workers.push(worker); - } - - // Send jobs to workers - for expr_batch in expression_batches { - curr_batch += 1; - let mut args = args.clone(); - args.update_expressions(expr_batch.clone()); - - log::info!("Processed {} out of {} batches", curr_batch, total_batches); - thread::sleep(Duration::from_millis(API_DELAY_MILLIS)); - - sender.send((args, expr_batch)).unwrap(); - } - - drop(sender); // Close the channel so workers can finish - - // Wait for all workers to finish - for worker in workers { - worker.join().unwrap(); - } - - let mut okay_responses = Arc::try_unwrap(okay_responses) - .unwrap() - .into_inner() - .unwrap(); - let failed_batches = Arc::try_unwrap(failed_batches) - .unwrap() - .into_inner() - .unwrap(); - - if !failed_batches.is_empty() && max_retries == 0 { - return Err("Max retries reached".into()); - } - - if !failed_batches.is_empty() && max_retries > 0 { - log::info!("Retrying failed batches"); - let mut new_args = args.clone(); - new_args.update_expressions(failed_batches.concat()); - log::info!("Retrying with {} failed expressions", failed_batches.len()); - let mut retry_responses = - _fetch_timeseries_recursive(dq_requester, new_args, max_retries - 1)?; - okay_responses.append(&mut retry_responses); - } - - log::info!("Returning {} responses", okay_responses.len()); - Ok(okay_responses) + .collect() } - #[allow(dead_code)] fn main() { let client_id = std::env::var("DQ_CLIENT_ID").unwrap(); diff --git a/src/download/timeseries.rs b/src/download/timeseries.rs index abd7c99..fb0746c 100644 --- a/src/download/timeseries.rs +++ b/src/download/timeseries.rs @@ -8,9 +8,6 @@ use std::collections::HashSet; use std::error::Error; use std::fs::File; - - - #[derive(Debug, Clone)] pub struct DQTimeseriesRequestArgs { pub start_date: String, @@ -79,8 +76,6 @@ impl Default for DQTimeseriesRequestArgs { } } - - /// Response from the DataQuery API #[derive(Deserialize, Debug)] pub struct DQTimeSeriesResponse { @@ -144,6 +139,37 @@ struct Attribute { time_series: Vec<(String, Option)>, } +impl Attribute { + /// Get the ticker from the expression + pub fn get_ticker(&self) -> Result> { + if !self.expression.starts_with("DB(JPMAQS,") { + return Err("Expression does not start with 'DB(JPMAQS,'".into()); + } + let ticker = self.expression.split(',').nth(1).unwrap(); + if ticker.is_empty() { + return Err("Ticker is empty".into()); + } + Ok(ticker.to_string()) + } + + /// Get the metric from the expression + pub fn get_metric(&self) -> Result> { + if !self.expression.starts_with("DB(JPMAQS,") { + return Err("Expression does not start with 'DB(JPMAQS,'".into()); + } + let metric = self + .expression + .trim_end_matches(')') + .split(',') + .last() + .unwrap(); + if metric.is_empty() { + return Err("Metric is empty".into()); + } + Ok(metric.to_string()) + } +} + /// Representation of a single time series #[derive(Debug)] pub struct DQTimeSeries { @@ -253,6 +279,31 @@ impl DQTimeSeriesResponse { timeseries_by_ticker.into_iter().map(|(_, v)| v).collect() } + /// Consume the DQTimeSeriesResponse by grouping the time series by ticker. + /// This function can only be called once as it takes ownership of the data. + pub fn consume_to_grouped_by_ticker(mut self) -> Vec> { + // Take the instruments vector, leaving an empty one in its place. + let instruments = std::mem::take(&mut self.instruments); + + // Group time series by ticker + let mut timeseries_by_ticker: HashMap> = HashMap::new(); + + for instrument in instruments { + for attribute in instrument.attributes { + let ticker = attribute.get_ticker().unwrap_or_default(); + timeseries_by_ticker + .entry(ticker) + .or_default() + .push(DQTimeSeries { + expression: attribute.expression, + time_series: attribute.time_series, + }); + } + } + + // Convert the HashMap into a Vec of Vecs + timeseries_by_ticker.into_iter().map(|(_, v)| v).collect() + } } impl JPMaQSIndicator { diff --git a/src/main.rs b/src/main.rs index 17b2fab..ca85d91 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,7 +20,7 @@ fn main() { start.elapsed() ); - let num_ticks = 2500; + let num_ticks = 1000; let sel_tickers: Vec = tickers .iter() .take(num_ticks)