From 289cd9cc0b18d6b4f7095d818d5adf2c65c46edd Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Mon, 11 Nov 2024 17:11:00 +0000 Subject: [PATCH] working! --- src/download/jpmaqsdownload.rs | 4 +- src/download/parreq.rs | 139 +++++++++++++++++++++++++++++++++ src/download/requester.rs | 54 ++++++++++--- src/main.rs | 36 ++++++--- 4 files changed, 211 insertions(+), 22 deletions(-) create mode 100644 src/download/parreq.rs diff --git a/src/download/jpmaqsdownload.rs b/src/download/jpmaqsdownload.rs index c837597..ced37eb 100644 --- a/src/download/jpmaqsdownload.rs +++ b/src/download/jpmaqsdownload.rs @@ -1,6 +1,6 @@ use crate::download::oauth_client::OAuthClient; use crate::download::requester::DQRequester; -use crate::download::requester::DQTimeseriesRequestArgs; +use crate::download::timeseries::DQTimeseriesRequestArgs; use crate::download::timeseries::DQTimeSeriesResponse; use crate::download::timeseries::JPMaQSIndicator; use rayon::prelude::*; @@ -95,7 +95,7 @@ impl JPMaQSDownload { let dqts_vec = self.requester.get_timeseries(DQTimeseriesRequestArgs { expressions: expressions, ..Default::default() - })?; + }).unwrap(); Ok(dqts_vec) } diff --git a/src/download/parreq.rs b/src/download/parreq.rs new file mode 100644 index 0000000..8e9d1ef --- /dev/null +++ b/src/download/parreq.rs @@ -0,0 +1,139 @@ +use crate::download::oauth_client::OAuthClient; +use crate::download::timeseries::DQTimeseriesRequestArgs; +use futures::future; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; +use std::error::Error; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use tokio::runtime::Builder; +use tokio::sync::Semaphore; +use tokio::task; + +const API_BASE_URL: &str = "https://api-developer.jpmorgan.com/research/dataquery-authe/api/v2"; +const TIMESERIES_ENDPOINT: &str = "/expressions/time-series"; + +#[derive(Clone)] +pub struct ParallelRequester { + oauth_client: Arc>, + max_threads: usize, +} + +impl ParallelRequester { + pub fn new(oauth_client: OAuthClient) -> Self { + ParallelRequester { + oauth_client: Arc::new(Mutex::new(oauth_client)), + max_threads: 250, + } + } + + fn build_headers(&self) -> Result> { + let headers_map = self.oauth_client.lock().unwrap().get_headers()?; + let mut headers = HeaderMap::new(); + for (key, value) in headers_map { + let header_name = HeaderName::from_bytes(key.as_bytes())?; + let header_value = HeaderValue::from_str(&value)?; + headers.insert(header_name, header_value); + } + Ok(headers) + } + + async fn _request_async(&self, endpoint: &str) -> Result> { + let headers = self.build_headers()?; + let url = format!("{}{}", API_BASE_URL, endpoint); + + let response = reqwest::Client::new() + .get(&url) + .headers(headers) + .send() + .await?; + + if response.status().is_success() { + let text = response.text().await?; + Ok(text) + } else { + Err(Box::new(response.error_for_status().unwrap_err())) + } + } + + pub fn request_expressions( + &mut self, + args: DQTimeseriesRequestArgs, + max_retries: u32, + ) -> Result, Box> { + let expression_batches: Vec> = args + .expressions + .chunks(20) + .map(|chunk| { + let mut vec = chunk.to_vec(); + vec.sort(); + vec + }) + .collect(); + + let okay_response_texts = Arc::new(Mutex::new(Vec::new())); + let failed_batches = Arc::new(Mutex::new(Vec::new())); + let self_arc = Arc::new(self.clone()); + + let runtime = Builder::new_current_thread().enable_all().build()?; + let semaphore = Arc::new(Semaphore::new(self.max_threads)); + + runtime.block_on(async { + let mut tasks = vec![]; + let mut curr_batch = 0; + let total_batches = expression_batches.len(); + for batch in expression_batches.into_iter() { + let self_arc = Arc::clone(&self_arc); + let okay_response_texts = Arc::clone(&okay_response_texts); + let failed_batches = Arc::clone(&failed_batches); + let mut args_clone = args.clone(); + args_clone.update_expressions(batch.clone()); + let ep = format!("{}?{}", TIMESERIES_ENDPOINT, args_clone.as_query_string()); + let permit = semaphore.clone().acquire_owned().await.unwrap(); + println!("Requesting batch {} of {}", curr_batch + 1, total_batches); + let task = task::spawn(async move { + let _permit = permit; // Keep the permit alive until the end of the task + let res_str = self_arc._request_async(&ep).await; + + match res_str { + Ok(text) => { + okay_response_texts.lock().unwrap().push(text); + println!("Batch {} of {} successful", curr_batch, total_batches); + } + Err(_) => { + failed_batches.lock().unwrap().push(batch); + } + } + }); + + tasks.push(task); + + // Delay before starting the next task + tokio::time::sleep(Duration::from_millis(250)).await; + curr_batch += 1; + } + + future::join_all(tasks).await; + }); + + // Retry failed batches if any, respecting the max_retries limit + let failed_batches = failed_batches.lock().unwrap().clone(); + if !failed_batches.is_empty() && max_retries > 0 { + let mut new_args = args.clone(); + let flattened_failed_batches: Vec = + failed_batches.into_iter().flatten().collect(); + new_args.update_expressions(flattened_failed_batches); + + let retried_responses = self.request_expressions(new_args, max_retries - 1)?; + okay_response_texts + .lock() + .unwrap() + .extend(retried_responses); + } else if !failed_batches.is_empty() && max_retries == 0 { + return Err("Max retries reached".into()); + } + + // Collect and return successful responses + let final_responses = okay_response_texts.lock().unwrap().clone(); + Ok(final_responses) + } +} diff --git a/src/download/requester.rs b/src/download/requester.rs index 076424e..ece0b21 100644 --- a/src/download/requester.rs +++ b/src/download/requester.rs @@ -1,9 +1,16 @@ use crate::download::oauth_client::OAuthClient; +use crate::download::parreq::ParallelRequester; use crate::download::timeseries::DQCatalogueResponse; use crate::download::timeseries::DQCatalogueSingleResponse; use crate::download::timeseries::DQTimeSeriesResponse; use crate::download::timeseries::DQTimeseriesRequestArgs; use crossbeam::channel; +use rayon::iter::IntoParallelRefMutIterator; +use rayon::iter::ParallelDrainRange; +use rayon::iter::{ParallelIterator, IntoParallelIterator}; + +use rayon::prelude::*; +// use futures::TryFutureExt; use reqwest; use reqwest::blocking::Client; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; @@ -144,18 +151,43 @@ impl DQRequester { Ok(response) } + pub fn get_timeseries( + &mut self, + args: DQTimeseriesRequestArgs, + ) -> Result, Box> { + let max_retries = 5; + println!( + "Invoking ParallelRequester for {:?} expressions", + args.expressions.len() + ); + let mut pq = ParallelRequester::new(self.oauth_client.clone()); + let pq_output = pq.request_expressions(args, max_retries); + + let response_texts = match pq_output { + Ok(r) => r, + Err(e) => return Err(e), + }; + + // Parse the responses using rayon's parallel iterator + let dqts_vec: Vec = response_texts + .par_iter() + .filter_map(|rtext| serde_json::from_str::(rtext).ok()) + .collect(); + + Ok(dqts_vec) + } } - // pub fn get_timeseries( - // &mut self, - // args: DQTimeseriesRequestArgs, - // ) -> Result, Box> { - // let max_retries = 5; - // println!( - // "Invoking recursive function for {:?} expressions", - // args.expressions.len() - // ); - // _fetch_timeseries_recursive(self, args, max_retries) - // } +// pub fn get_timeseries( +// &mut self, +// args: DQTimeseriesRequestArgs, +// ) -> Result, Box> { +// let max_retries = 5; +// println!( +// "Invoking recursive function for {:?} expressions", +// args.expressions.len() +// ); +// _fetch_timeseries_recursive(self, args, max_retries) +// } // } fn _fetch_expression_batch( diff --git a/src/main.rs b/src/main.rs index eb3f278..17b2fab 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,8 @@ use msyrs::download::jpmaqsdownload::{JPMaQSDownload, JPMaQSDownloadGetIndicatorArgs}; +use polars::error::PolarsError; +use polars::export::chrono::NaiveDate; +use polars::prelude::*; +use polars::series::Series; fn main() { println!("Authentication to DataQuery API"); @@ -16,7 +20,7 @@ fn main() { start.elapsed() ); - let num_ticks = 50; + let num_ticks = 2500; let sel_tickers: Vec = tickers .iter() .take(num_ticks) @@ -40,14 +44,28 @@ fn main() { ); start = std::time::Instant::now(); - for indicator in indicators { - // df_deets.push((indicator.ticker.clone(), indicator.as_qdf().unwrap().size())); - println!( - "Ticker: {}, DataFrame: {}", - indicator.ticker, - indicator.as_qdf().unwrap() - ); - } + // let mut qdf_list = Vec::new(); + let mega_df = indicators + .iter() + .map(|indicator| indicator.as_qdf().unwrap()) + .fold(DataFrame::new(vec![]).unwrap(), |acc, df| { + acc.vstack(&df).unwrap() + }); + // for indicator in indicators { + // // df_deets.push((indicator.ticker.clone(), indicator.as_qdf().unwrap().size())); + // // println!( + // // "Ticker: {}, DataFrame", + // // indicator.ticker, + // // // indicator.as_qdf().unwrap() + // // ); + // qdf_list.push(indicator.as_qdf().unwrap()); + // } + + // // vstack the DataFrames + // let qdf = DataFrame::new(qdf_list).unwrap(); + // print mega_df size + println!("Mega DataFrame size: {:?}", mega_df.size()); + println!( "Converted indicators to DataFrames in {:?}", start.elapsed()