diff --git a/src/download/oauth_client.rs b/src/download/oauth_client.rs
index f71ebfd..8c39d7d 100644
--- a/src/download/oauth_client.rs
+++ b/src/download/oauth_client.rs
@@ -30,7 +30,7 @@ impl OAuthClient {
         }
     }
 
-    pub fn fetch_token(&mut self) -> Result<(), ReqwestError> {
+    fn fetch_token(&mut self) -> Result<(), ReqwestError> {
         let client = Client::new();
 
         // Set up the form parameters for the request
diff --git a/src/download/parreq.rs b/src/download/parreq.rs
index 0251a53..7282089 100644
--- a/src/download/parreq.rs
+++ b/src/download/parreq.rs
@@ -12,20 +12,26 @@ use tokio::task;
 
 const API_BASE_URL: &str = "https://api-developer.jpmorgan.com/research/dataquery-authe/api/v2";
 const TIMESERIES_ENDPOINT: &str = "/expressions/time-series";
 
+const DQ_THREAD_EXPIRY_SECONDS: u64 = 300;
+const DQ_THREAD_SLEEP_MILLIS: u64 = 200;
+const DQ_MAX_THREADS: usize = 250;
+
 #[derive(Clone)]
 pub struct ParallelRequester {
     oauth_client: Arc<Mutex<OAuthClient>>,
     max_threads: usize,
 }
 
+/// `ParallelRequester` makes parallel, batched requests to the DataQuery API.
 impl ParallelRequester {
     pub fn new(oauth_client: OAuthClient) -> Self {
         ParallelRequester {
             oauth_client: Arc::new(Mutex::new(oauth_client)),
-            max_threads: 250,
+            max_threads: DQ_MAX_THREADS,
         }
     }
 
+    /// Builds the headers for a request to the DataQuery API.
     fn build_headers(&self) -> Result<HeaderMap, Box<dyn std::error::Error + Send + Sync>> {
         let headers_map = self.oauth_client.lock().unwrap().get_headers()?;
         let mut headers = HeaderMap::new();
@@ -37,6 +43,7 @@ impl ParallelRequester {
         Ok(headers)
     }
 
+    /// Makes an asynchronous request to the DataQuery API.
     async fn _request_async(&self, endpoint: &str) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
         let headers = self.build_headers()?;
         let url = format!("{}{}", API_BASE_URL, endpoint);
@@ -55,7 +62,7 @@ impl ParallelRequester {
         }
     }
 
-    pub fn request_expressions(
+    fn _request_expressions(
        &mut self,
        args: DQTimeseriesRequestArgs,
        max_retries: u32,
@@ -89,20 +96,25 @@
            args_clone.update_expressions(batch.clone());
            let ep = format!("{}?{}", TIMESERIES_ENDPOINT, args_clone.as_query_string());
            let permit = semaphore.clone().acquire_owned().await.unwrap();
-            // println!("Requesting batch {} of {}", curr_batch + 1, total_batches);
+            if curr_batch % 10 == 0 { println!("Requesting batch {} of {}", curr_batch, total_batches); }
+
            let task = task::spawn(async move {
                let _permit = permit; // Keep the permit alive until the end of the task
-                let res_str = self_arc._request_async(&ep).await;
+                let res_future = self_arc._request_async(&ep);
 
-                match res_str {
-                    Ok(text) => {
+                let res = tokio::time::timeout(
+                    Duration::from_secs(DQ_THREAD_EXPIRY_SECONDS),
+                    res_future,
+                )
+                .await;
+                match res {
+                    Ok(Ok(text)) => {
                        okay_response_texts.lock().unwrap().push(text);
-                        println!("Batch {} of {} successful", curr_batch, total_batches);
                    }
-                    Err(_) => {
+                    _ => {
                        failed_batches.lock().unwrap().push(batch);
                    }
                }
@@ -111,11 +123,16 @@
            tasks.push(task);
 
            // Delay before starting the next task
-            tokio::time::sleep(Duration::from_millis(250)).await;
+            tokio::time::sleep(Duration::from_millis(DQ_THREAD_SLEEP_MILLIS)).await;
            curr_batch += 1;
        }
 
-        future::join_all(tasks).await;
+        // Await all tasks, but stop waiting after DQ_THREAD_EXPIRY_SECONDS; stragglers are detached, not cancelled
+        let _ = tokio::time::timeout(
+            Duration::from_secs(DQ_THREAD_EXPIRY_SECONDS),
+            future::join_all(tasks),
+        )
+        .await;
    });
 
    // Retry failed batches if any, respecting the max_retries limit
@@ -126,7 +143,7 @@
        let flattened_failed_batches: Vec<String> =
            failed_batches.into_iter().flatten().collect();
        new_args.update_expressions(flattened_failed_batches);
-        let retried_responses = self.request_expressions(new_args, max_retries - 1)?;
+        let retried_responses = self._request_expressions(new_args, max_retries - 1)?;
        okay_response_texts
            .lock()
            .unwrap()
@@ -139,4 +156,14 @@ impl ParallelRequester {
        let final_responses = okay_response_texts.lock().unwrap().clone();
        Ok(final_responses)
    }
+
+    /// Makes parallel requests to the DataQuery API for the given expressions,
+    /// collecting the responses into a vector of response-body strings.
+    pub fn request_expressions(
+        &mut self,
+        args: DQTimeseriesRequestArgs,
+        max_retries: u32,
+    ) -> Result<Vec<String>, Box<dyn std::error::Error>> {
+        self._request_expressions(args, max_retries)
+    }
 }
diff --git a/src/download/requester.rs b/src/download/requester.rs
index ea1fc7e..fcec710 100644
--- a/src/download/requester.rs
+++ b/src/download/requester.rs
@@ -43,6 +43,7 @@ impl DQRequester {
         }
     }
 
+    /// Internal helper that performs a single request to the DataQuery API.
     fn _request(
         &mut self,
         method: reqwest::Method,
@@ -79,11 +80,12 @@ impl DQRequester {
         Ok(())
     }
 
+    /// Fetches the catalogue of tickers from the DataQuery API.
+    /// The returned `DQCatalogueResponse` contains a list of `DQCatalogueSingleResponse` objects.
     pub fn get_catalogue(
         &mut self,
         catalogue_group: &str,
         page_size: u32,
-        // ) -> Result<Vec<DQCatalogueSingleResponse>, Box<dyn std::error::Error>> {
     ) -> Result<DQCatalogueResponse, Box<dyn std::error::Error>> {
         let mut responses: Vec<DQCatalogueSingleResponse> = Vec::new();
@@ -260,40 +262,3 @@ fn parse_response_texts_to_jpmaqs_indicators(
     jpmaqs_indicators_map.into_iter().map(|(_, v)| v).collect()
 }
 
-#[allow(dead_code)]
-fn main() {
-    let client_id = std::env::var("DQ_CLIENT_ID").unwrap();
-    let client_secret = std::env::var("DQ_CLIENT_SECRET").unwrap();
-
-    let mut oauth_client = OAuthClient::new(client_id, client_secret);
-    oauth_client.fetch_token().unwrap();
-
-    let mut requester = DQRequester::new(oauth_client);
-    requester.check_connection().unwrap();
-
-    // let response = requester
-    //     .get_catalogue(JPMAQS_CATALOGUE_GROUP, 1000)
-    //     .unwrap();
-
-    // let json_data = response
-
-    // try to pull into
-
-    // let expressions_a = vec![
-    //     "DB(JPMAQS,USD_EQXR_NSA,value)",
-    //     "DB(JPMAQS,USD_EQXR_NSA,grading)",
-    //     "DB(JPMAQS,USD_EQXR_NSA,eop_lag)",
-    //     "DB(JPMAQS,USD_EQXR_NSA,mop_lag)",
-    //     "DB(JPMAQS,GBP_EQXR_NSA,value)",
-    //     "DB(JPMAQS,GBP_EQXR_NSA,grading)",
-    //     "DB(JPMAQS,GBP_EQXR_NSA,eop_lag)",
-    //     "DB(JPMAQS,GBP_EQXR_NSA,mop_lag)",
-    // ];
-
-    // let response = requester
-    //     .get_timeseries_with_defaults(expressions_a.iter().map(|s| *s).collect())
-    //     .unwrap();
-
-    // let json_data = response.text().unwrap();
-    // println!("{}", json_data);
-}
diff --git a/src/download/timeseries.rs b/src/download/timeseries.rs
index 6835fa8..8880485 100644
--- a/src/download/timeseries.rs
+++ b/src/download/timeseries.rs
@@ -474,54 +474,3 @@ fn save_qdf_to_csv(qdf: &mut DataFrame, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
-            Ok(df) => println!("{:?}", df),
-            Err(e) => println!("Failed to create DataFrame: {:?}", e),
-        }
-    }
-}
diff --git a/src/main.rs b/src/main.rs
index ca85d91..87dbd98 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,8 +1,5 @@
 use msyrs::download::jpmaqsdownload::{JPMaQSDownload, JPMaQSDownloadGetIndicatorArgs};
-use polars::error::PolarsError;
-use polars::export::chrono::NaiveDate;
 use polars::prelude::*;
-use polars::series::Series;
 
 fn main() {
     println!("Authentication to DataQuery API");
@@ -20,7 +17,7 @@ fn main() {
         start.elapsed()
     );
 
-    let num_ticks = 1000;
+    let num_ticks = 5000;
     let sel_tickers: Vec<String> = tickers
         .iter()
         .take(num_ticks)
@@ -43,6 +40,11 @@ fn main() {
         start.elapsed()
     );
 
+    // sleep for 10 seconds
+    println!("Sleeping for 10 seconds...");
+    std::thread::sleep(std::time::Duration::from_secs(10));
+    println!("Concatenating into mega DataFrame");
+    start = std::time::Instant::now();
     // let mut qdf_list = Vec::new();
     let mega_df = indicators
@@ -51,21 +53,10 @@ fn main() {
         .fold(DataFrame::new(vec![]).unwrap(), |acc, df| {
             acc.vstack(&df).unwrap()
         });
-    // for indicator in indicators {
-    //     // df_deets.push((indicator.ticker.clone(), indicator.as_qdf().unwrap().size()));
-    //     // println!(
-    //     //     "Ticker: {}, DataFrame",
-    //     //     indicator.ticker,
-    //     //     // indicator.as_qdf().unwrap()
-    //     // );
-    //     qdf_list.push(indicator.as_qdf().unwrap());
-    // }
-
-    // // vstack the DataFrames
-    // let qdf = DataFrame::new(qdf_list).unwrap();
-    // print mega_df size
-    println!("Mega DataFrame size: {:?}", mega_df.size());
+
+    let es = mega_df.estimated_size();
+    let es_mb = es as f64 / 1_048_576.0;
+    println!("Estimated size of DataFrame: {:.2} MB", es_mb);
     println!(
         "Converted indicators to DataFrames in {:?}",
         start.elapsed()
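
Note on the timeout wrapping introduced in parreq.rs above: tokio::time::timeout does not kill the wrapped work, it simply stops waiting and yields Err(Elapsed), which is why join_all stragglers end up detached and why the success arm has to match Ok(Ok(text)) (the timeout Result wraps the request's own Result). A minimal, self-contained sketch of the pattern follows; it assumes a tokio dependency with the "time", "rt" and "macros" features, and fetch_one is an illustrative stand-in, not a function from this patch:

use std::time::Duration;

// Illustrative stand-in for something like ParallelRequester::_request_async.
async fn fetch_one() -> Result<String, std::io::Error> {
    tokio::time::sleep(Duration::from_millis(50)).await;
    Ok("response body".to_string())
}

#[tokio::main]
async fn main() {
    // timeout() resolves to Ok(inner_result) if the future finishes in time,
    // or Err(Elapsed) if the deadline passes first.
    match tokio::time::timeout(Duration::from_secs(300), fetch_one()).await {
        Ok(Ok(text)) => println!("success: {}", text),   // finished and succeeded
        Ok(Err(e)) => println!("request failed: {}", e), // finished but errored
        Err(_elapsed) => println!("timed out"),          // deadline hit; future dropped
    }
}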
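
With fetch_token() made private and request_expressions() exposed as the public entry point, a caller would drive the new API roughly as below. This is a sketch, not code from the repo: the module paths are inferred from the file layout, DQTimeseriesRequestArgs is assumed to implement Default (its real constructor is not shown in this patch), and the token is assumed to be acquired internally once fetch_token() is private:

use msyrs::download::oauth_client::OAuthClient; // assumed module path
use msyrs::download::parreq::ParallelRequester; // assumed module path
use msyrs::download::timeseries::DQTimeseriesRequestArgs; // assumed module path

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client_id = std::env::var("DQ_CLIENT_ID")?;
    let client_secret = std::env::var("DQ_CLIENT_SECRET")?;

    // fetch_token() is private after this patch, so the token is assumed to be
    // fetched internally when headers are first built.
    let oauth_client = OAuthClient::new(client_id, client_secret);
    let mut requester = ParallelRequester::new(oauth_client);

    // Assumes a Default impl; the real construction of the args is not shown here.
    let mut args = DQTimeseriesRequestArgs::default();
    args.update_expressions(vec![
        "DB(JPMAQS,USD_EQXR_NSA,value)".to_string(),
        "DB(JPMAQS,GBP_EQXR_NSA,value)".to_string(),
    ]);

    // Up to 3 rounds of retries for failed batches.
    let responses: Vec<String> = requester.request_expressions(args, 3)?;
    println!("Received {} response bodies", responses.len());
    Ok(())
}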