commit 289cd9cc0b (parent d506fd23c4)
Palash Tyagi, 2024-11-11 17:11:00 +00:00
4 changed files with 211 additions and 22 deletions

src/download/jpmaqsdownload.rs
@@ -1,6 +1,6 @@
 use crate::download::oauth_client::OAuthClient;
 use crate::download::requester::DQRequester;
-use crate::download::requester::DQTimeseriesRequestArgs;
+use crate::download::timeseries::DQTimeseriesRequestArgs;
 use crate::download::timeseries::DQTimeSeriesResponse;
 use crate::download::timeseries::JPMaQSIndicator;
 use rayon::prelude::*;
@@ -95,7 +95,7 @@ impl JPMaQSDownload
         let dqts_vec = self.requester.get_timeseries(DQTimeseriesRequestArgs {
             expressions: expressions,
             ..Default::default()
-        })?;
+        }).unwrap();
         Ok(dqts_vec)
     }

src/download/parreq.rs (new file)

@@ -0,0 +1,139 @@
use crate::download::oauth_client::OAuthClient;
use crate::download::timeseries::DQTimeseriesRequestArgs;
use futures::future;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use std::error::Error;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::runtime::Builder;
use tokio::sync::Semaphore;
use tokio::task;
const API_BASE_URL: &str = "https://api-developer.jpmorgan.com/research/dataquery-authe/api/v2";
const TIMESERIES_ENDPOINT: &str = "/expressions/time-series";
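
/// Fires DataQuery requests concurrently over a shared OAuth client.
/// Note: `max_threads` bounds concurrent in-flight requests (tokio tasks),
/// not operating-system threads.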
#[derive(Clone)]
pub struct ParallelRequester {
    oauth_client: Arc<Mutex<OAuthClient>>,
    max_threads: usize,
}
impl ParallelRequester {
    pub fn new(oauth_client: OAuthClient) -> Self {
        ParallelRequester {
            oauth_client: Arc::new(Mutex::new(oauth_client)),
            max_threads: 250,
        }
    }
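
    /// Builds a reqwest `HeaderMap` from the OAuth client's header key/value pairs.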
    fn build_headers(&self) -> Result<HeaderMap, Box<dyn Error>> {
        let headers_map = self.oauth_client.lock().unwrap().get_headers()?;
        let mut headers = HeaderMap::new();
        for (key, value) in headers_map {
            let header_name = HeaderName::from_bytes(key.as_bytes())?;
            let header_value = HeaderValue::from_str(&value)?;
            headers.insert(header_name, header_value);
        }
        Ok(headers)
    }
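
    /// Performs a single authenticated GET against the DataQuery API and
    /// returns the raw response body as text.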
    async fn _request_async(&self, endpoint: &str) -> Result<String, Box<dyn Error>> {
        let headers = self.build_headers()?;
        let url = format!("{}{}", API_BASE_URL, endpoint);
        let response = reqwest::Client::new()
            .get(&url)
            .headers(headers)
            .send()
            .await?;
        if response.status().is_success() {
            let text = response.text().await?;
            Ok(text)
        } else {
            // The status is already known to be an error here, so
            // error_for_status() always yields Err and unwrap_err() cannot panic.
            Err(Box::new(response.error_for_status().unwrap_err()))
        }
    }
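
    /// Splits the requested expressions into batches of 20, requests the batches
    /// concurrently (throttled by the semaphore and a 250 ms spawn delay), and
    /// recursively retries failed batches until `max_retries` is exhausted.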
    pub fn request_expressions(
        &mut self,
        args: DQTimeseriesRequestArgs,
        max_retries: u32,
    ) -> Result<Vec<String>, Box<dyn Error>> {
        let expression_batches: Vec<Vec<String>> = args
            .expressions
            .chunks(20)
            .map(|chunk| {
                let mut vec = chunk.to_vec();
                vec.sort();
                vec
            })
            .collect();
        let okay_response_texts = Arc::new(Mutex::new(Vec::new()));
        let failed_batches = Arc::new(Mutex::new(Vec::new()));
        let self_arc = Arc::new(self.clone());
        let runtime = Builder::new_current_thread().enable_all().build()?;
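        // A single-threaded runtime suffices: the work is I/O-bound, and the
        // semaphore (not threads) bounds how many requests are in flight.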
        let semaphore = Arc::new(Semaphore::new(self.max_threads));
        runtime.block_on(async {
            let mut tasks = vec![];
            let mut curr_batch = 0;
            let total_batches = expression_batches.len();
            for batch in expression_batches.into_iter() {
                let self_arc = Arc::clone(&self_arc);
                let okay_response_texts = Arc::clone(&okay_response_texts);
                let failed_batches = Arc::clone(&failed_batches);
                let mut args_clone = args.clone();
                args_clone.update_expressions(batch.clone());
                let ep = format!("{}?{}", TIMESERIES_ENDPOINT, args_clone.as_query_string());
                let permit = semaphore.clone().acquire_owned().await.unwrap();
                println!("Requesting batch {} of {}", curr_batch + 1, total_batches);
                let task = task::spawn(async move {
                    let _permit = permit; // Keep the permit alive until the end of the task
                    let res_str = self_arc._request_async(&ep).await;
                    match res_str {
                        Ok(text) => {
                            okay_response_texts.lock().unwrap().push(text);
                            println!("Batch {} of {} successful", curr_batch + 1, total_batches);
                        }
                        Err(_) => {
                            failed_batches.lock().unwrap().push(batch);
                        }
                    }
                });
                tasks.push(task);
                // Delay before starting the next task
                tokio::time::sleep(Duration::from_millis(250)).await;
                curr_batch += 1;
            }
            future::join_all(tasks).await;
        });
        // Retry failed batches if any, respecting the max_retries limit
        let failed_batches = failed_batches.lock().unwrap().clone();
        if !failed_batches.is_empty() && max_retries > 0 {
            let mut new_args = args.clone();
            let flattened_failed_batches: Vec<String> =
                failed_batches.into_iter().flatten().collect();
            new_args.update_expressions(flattened_failed_batches);
            let retried_responses = self.request_expressions(new_args, max_retries - 1)?;
            okay_response_texts
                .lock()
                .unwrap()
                .extend(retried_responses);
        } else if !failed_batches.is_empty() && max_retries == 0 {
            return Err("Max retries reached".into());
        }
        // Collect and return successful responses
        let final_responses = okay_response_texts.lock().unwrap().clone();
        Ok(final_responses)
    }
}
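
For orientation, a minimal sketch of how this type is driven from synchronous code. The `fetch_raw_responses` wrapper is hypothetical and the `OAuthClient` construction is assumed to happen elsewhere; `new`, the `expressions` field, and `request_expressions` are the APIs defined above.

// Hypothetical driver for ParallelRequester (sketch, not part of this commit).
use crate::download::oauth_client::OAuthClient;
use crate::download::parreq::ParallelRequester;
use crate::download::timeseries::DQTimeseriesRequestArgs;
use std::error::Error;

fn fetch_raw_responses(
    client: OAuthClient,
    expressions: Vec<String>,
) -> Result<Vec<String>, Box<dyn Error>> {
    let mut pq = ParallelRequester::new(client);
    let args = DQTimeseriesRequestArgs {
        expressions,
        ..Default::default()
    };
    // Allow up to 5 retry passes over batches that failed.
    pq.request_expressions(args, 5)
}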

src/download/requester.rs
@@ -1,9 +1,16 @@
 use crate::download::oauth_client::OAuthClient;
+use crate::download::parreq::ParallelRequester;
 use crate::download::timeseries::DQCatalogueResponse;
 use crate::download::timeseries::DQCatalogueSingleResponse;
 use crate::download::timeseries::DQTimeSeriesResponse;
 use crate::download::timeseries::DQTimeseriesRequestArgs;
 use crossbeam::channel;
+use rayon::iter::IntoParallelRefMutIterator;
+use rayon::iter::ParallelDrainRange;
+use rayon::iter::{ParallelIterator, IntoParallelIterator};
+use rayon::prelude::*;
+// use futures::TryFutureExt;
 use reqwest;
 use reqwest::blocking::Client;
 use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
@@ -144,18 +151,43 @@ impl DQRequester
         Ok(response)
     }
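+
+    /// Fetches time series for the requested expressions via `ParallelRequester`
+    /// and parses the JSON response bodies in parallel with rayon.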
+    pub fn get_timeseries(
+        &mut self,
+        args: DQTimeseriesRequestArgs,
+    ) -> Result<Vec<DQTimeSeriesResponse>, Box<dyn Error>> {
+        let max_retries = 5;
+        println!(
+            "Invoking ParallelRequester for {:?} expressions",
+            args.expressions.len()
+        );
+        let mut pq = ParallelRequester::new(self.oauth_client.clone());
+        let pq_output = pq.request_expressions(args, max_retries);
+        let response_texts = match pq_output {
+            Ok(r) => r,
+            Err(e) => return Err(e),
+        };
+        // Parse the responses using rayon's parallel iterator
+        let dqts_vec: Vec<DQTimeSeriesResponse> = response_texts
+            .par_iter()
+            .filter_map(|rtext| serde_json::from_str::<DQTimeSeriesResponse>(rtext).ok())
+            .collect();
+        Ok(dqts_vec)
+    }
 }

 // pub fn get_timeseries(
 //     &mut self,
 //     args: DQTimeseriesRequestArgs,
 // ) -> Result<Vec<DQTimeSeriesResponse>, Box<dyn Error>> {
 //     let max_retries = 5;
 //     println!(
 //         "Invoking recursive function for {:?} expressions",
 //         args.expressions.len()
 //     );
 //     _fetch_timeseries_recursive(self, args, max_retries)
 // }
 // }

 fn _fetch_expression_batch(

src/main.rs
@@ -1,4 +1,8 @@
 use msyrs::download::jpmaqsdownload::{JPMaQSDownload, JPMaQSDownloadGetIndicatorArgs};
+use polars::error::PolarsError;
+use polars::export::chrono::NaiveDate;
+use polars::prelude::*;
+use polars::series::Series;

 fn main() {
     println!("Authentication to DataQuery API");
@@ -16,7 +20,7 @@ fn main() {
         start.elapsed()
     );
-    let num_ticks = 50;
+    let num_ticks = 2500;
     let sel_tickers: Vec<String> = tickers
         .iter()
         .take(num_ticks)
@@ -40,14 +44,28 @@
     );
     start = std::time::Instant::now();
-    for indicator in indicators {
-        // df_deets.push((indicator.ticker.clone(), indicator.as_qdf().unwrap().size()));
-        println!(
-            "Ticker: {}, DataFrame: {}",
-            indicator.ticker,
-            indicator.as_qdf().unwrap()
-        );
-    }
+    // let mut qdf_list = Vec::new();
+    let mega_df = indicators
+        .iter()
+        .map(|indicator| indicator.as_qdf().unwrap())
+        .fold(DataFrame::new(vec![]).unwrap(), |acc, df| {
+            acc.vstack(&df).unwrap()
+        });
+    // for indicator in indicators {
+    //     // df_deets.push((indicator.ticker.clone(), indicator.as_qdf().unwrap().size()));
+    //     // println!(
+    //     //     "Ticker: {}, DataFrame",
+    //     //     indicator.ticker,
+    //     //     // indicator.as_qdf().unwrap()
+    //     //     );
+    //     qdf_list.push(indicator.as_qdf().unwrap());
+    // }
+    // // vstack the DataFrames
+    // let qdf = DataFrame::new(qdf_list).unwrap();
+
+    // print mega_df size
+    println!("Mega DataFrame size: {:?}", mega_df.size());
     println!(
         "Converted indicators to DataFrames in {:?}",
         start.elapsed()