Palash Tyagi 2024-11-11 17:11:00 +00:00
parent d506fd23c4
commit 289cd9cc0b
4 changed files with 211 additions and 22 deletions

View File

@@ -1,6 +1,6 @@
use crate::download::oauth_client::OAuthClient;
use crate::download::requester::DQRequester;
-use crate::download::requester::DQTimeseriesRequestArgs;
+use crate::download::timeseries::DQTimeseriesRequestArgs;
use crate::download::timeseries::DQTimeSeriesResponse;
use crate::download::timeseries::JPMaQSIndicator;
use rayon::prelude::*;
@@ -95,7 +95,7 @@ impl JPMaQSDownload
let dqts_vec = self.requester.get_timeseries(DQTimeseriesRequestArgs {
expressions: expressions,
..Default::default()
-})?;
+}).unwrap();
Ok(dqts_vec)
}
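For reference, the ..Default::default() call above relies on DQTimeseriesRequestArgs implementing Default. A minimal sketch of that struct-update pattern (the field set and the JPMaQS expression string are illustrative assumptions, not taken from this repository):

#[derive(Clone, Default)]
pub struct DQTimeseriesRequestArgs {
    pub expressions: Vec<String>,
    // ...other request parameters, all falling back to their Default values
}

fn main() {
    // Only expressions is set explicitly; the rest come from Default::default().
    let args = DQTimeseriesRequestArgs {
        expressions: vec!["DB(JPMAQS,USD_EQXR_NSA,value)".to_string()],
        ..Default::default()
    };
    println!("{} expressions requested", args.expressions.len());
}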

src/download/parreq.rs Normal file
View File

@@ -0,0 +1,139 @@
use crate::download::oauth_client::OAuthClient;
use crate::download::timeseries::DQTimeseriesRequestArgs;
use futures::future;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use std::error::Error;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::runtime::Builder;
use tokio::sync::Semaphore;
use tokio::task;
const API_BASE_URL: &str = "https://api-developer.jpmorgan.com/research/dataquery-authe/api/v2";
const TIMESERIES_ENDPOINT: &str = "/expressions/time-series";
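// The OAuthClient is shared behind Arc<Mutex<...>> so that concurrent tasks
// can reuse one authenticated client when building request headers.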
#[derive(Clone)]
pub struct ParallelRequester {
oauth_client: Arc<Mutex<OAuthClient>>,
max_threads: usize,
}
impl ParallelRequester {
pub fn new(oauth_client: OAuthClient) -> Self {
ParallelRequester {
oauth_client: Arc::new(Mutex::new(oauth_client)),
max_threads: 250, // cap on concurrent in-flight requests (enforced via a Semaphore in request_expressions)
}
}
fn build_headers(&self) -> Result<HeaderMap, Box<dyn Error>> {
let headers_map = self.oauth_client.lock().unwrap().get_headers()?;
let mut headers = HeaderMap::new();
for (key, value) in headers_map {
let header_name = HeaderName::from_bytes(key.as_bytes())?;
let header_value = HeaderValue::from_str(&value)?;
headers.insert(header_name, header_value);
}
Ok(headers)
}
async fn _request_async(&self, endpoint: &str) -> Result<String, Box<dyn Error>> {
let headers = self.build_headers()?;
let url = format!("{}{}", API_BASE_URL, endpoint);
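// Note: building a new reqwest Client per call forgoes connection pooling;
// reusing a single Client across requests would be cheaper.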
let response = reqwest::Client::new()
.get(&url)
.headers(headers)
.send()
.await?;
if response.status().is_success() {
let text = response.text().await?;
Ok(text)
} else {
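// The status is not a success, so error_for_status() is guaranteed to
// return Err here and unwrap_err() cannot panic.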
Err(Box::new(response.error_for_status().unwrap_err()))
}
}
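// Splits the expressions into batches of 20, requests all batches
// concurrently, and recursively retries failed batches until max_retries
// is exhausted.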
pub fn request_expressions(
&mut self,
args: DQTimeseriesRequestArgs,
max_retries: u32,
) -> Result<Vec<String>, Box<dyn Error>> {
let expression_batches: Vec<Vec<String>> = args
.expressions
.chunks(20)
.map(|chunk| {
let mut vec = chunk.to_vec();
vec.sort();
vec
})
.collect();
let okay_response_texts = Arc::new(Mutex::new(Vec::new()));
let failed_batches = Arc::new(Mutex::new(Vec::new()));
let self_arc = Arc::new(self.clone());
let runtime = Builder::new_current_thread().enable_all().build()?;
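// The semaphore caps the number of in-flight requests at max_threads.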
let semaphore = Arc::new(Semaphore::new(self.max_threads));
runtime.block_on(async {
let mut tasks = vec![];
let mut curr_batch = 0;
let total_batches = expression_batches.len();
for batch in expression_batches.into_iter() {
let self_arc = Arc::clone(&self_arc);
let okay_response_texts = Arc::clone(&okay_response_texts);
let failed_batches = Arc::clone(&failed_batches);
let mut args_clone = args.clone();
args_clone.update_expressions(batch.clone());
let ep = format!("{}?{}", TIMESERIES_ENDPOINT, args_clone.as_query_string());
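// Acquire a permit before spawning: once max_threads permits are taken,
// this awaits until a running task releases one.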
let permit = semaphore.clone().acquire_owned().await.unwrap();
println!("Requesting batch {} of {}", curr_batch + 1, total_batches);
let task = task::spawn(async move {
let _permit = permit; // Keep the permit alive until the end of the task
let res_str = self_arc._request_async(&ep).await;
match res_str {
Ok(text) => {
okay_response_texts.lock().unwrap().push(text);
println!("Batch {} of {} successful", curr_batch, total_batches);
}
Err(_) => {
failed_batches.lock().unwrap().push(batch);
}
}
});
tasks.push(task);
// Delay before starting the next task (crude rate limit: at most ~4 new requests per second)
tokio::time::sleep(Duration::from_millis(250)).await;
curr_batch += 1;
}
future::join_all(tasks).await;
});
// Retry failed batches if any, respecting the max_retries limit
let failed_batches = failed_batches.lock().unwrap().clone();
if !failed_batches.is_empty() && max_retries > 0 {
let mut new_args = args.clone();
let flattened_failed_batches: Vec<String> =
failed_batches.into_iter().flatten().collect();
new_args.update_expressions(flattened_failed_batches);
let retried_responses = self.request_expressions(new_args, max_retries - 1)?;
okay_response_texts
.lock()
.unwrap()
.extend(retried_responses);
} else if !failed_batches.is_empty() && max_retries == 0 {
return Err("Max retries reached".into());
}
// Collect and return successful responses
let final_responses = okay_response_texts.lock().unwrap().clone();
Ok(final_responses)
}
}
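A minimal usage sketch for ParallelRequester as introduced above, assuming a hypothetical OAuthClient constructor (its real signature is not shown in this diff); the expression string is likewise illustrative:

use msyrs::download::oauth_client::OAuthClient;
use msyrs::download::parreq::ParallelRequester;
use msyrs::download::timeseries::DQTimeseriesRequestArgs;

fn fetch_example() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical constructor; see oauth_client.rs for the real one.
    let oauth_client = OAuthClient::new("client_id", "client_secret");
    let mut requester = ParallelRequester::new(oauth_client);
    let args = DQTimeseriesRequestArgs {
        expressions: vec!["DB(JPMAQS,USD_EQXR_NSA,value)".to_string()],
        ..Default::default()
    };
    // Failed batches are retried up to 5 more times before erroring out.
    let response_texts = requester.request_expressions(args, 5)?;
    println!("received {} successful responses", response_texts.len());
    Ok(())
}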

View File

@@ -1,9 +1,16 @@
use crate::download::oauth_client::OAuthClient;
use crate::download::parreq::ParallelRequester;
use crate::download::timeseries::DQCatalogueResponse;
use crate::download::timeseries::DQCatalogueSingleResponse;
use crate::download::timeseries::DQTimeSeriesResponse;
use crate::download::timeseries::DQTimeseriesRequestArgs;
use crossbeam::channel;
use rayon::iter::IntoParallelRefMutIterator;
use rayon::iter::ParallelDrainRange;
use rayon::iter::{ParallelIterator, IntoParallelIterator};
use rayon::prelude::*;
// use futures::TryFutureExt;
use reqwest;
use reqwest::blocking::Client;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
@@ -144,18 +151,43 @@ impl DQRequester
Ok(response)
}
pub fn get_timeseries(
&mut self,
args: DQTimeseriesRequestArgs,
) -> Result<Vec<DQTimeSeriesResponse>, Box<dyn Error>> {
let max_retries = 5;
println!(
"Invoking ParallelRequester for {:?} expressions",
args.expressions.len()
);
let mut pq = ParallelRequester::new(self.oauth_client.clone());
let response_texts = pq.request_expressions(args, max_retries)?;
// Parse the responses in parallel with rayon; filter_map silently drops any response that fails to deserialize
let dqts_vec: Vec<DQTimeSeriesResponse> = response_texts
.par_iter()
.filter_map(|rtext| serde_json::from_str::<DQTimeSeriesResponse>(rtext).ok())
.collect();
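// A variant that logs parse failures instead of discarding them silently
// (a sketch, not part of this commit):
// let dqts_vec: Vec<DQTimeSeriesResponse> = response_texts
//     .par_iter()
//     .filter_map(|rtext| match serde_json::from_str::<DQTimeSeriesResponse>(rtext) {
//         Ok(r) => Some(r),
//         Err(e) => {
//             eprintln!("failed to parse DataQuery response: {}", e);
//             None
//         }
//     })
//     .collect();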
Ok(dqts_vec)
}
}
// pub fn get_timeseries(
// &mut self,
// args: DQTimeseriesRequestArgs,
// ) -> Result<Vec<DQTimeSeriesResponse>, Box<dyn Error>> {
// let max_retries = 5;
// println!(
// "Invoking recursive function for {:?} expressions",
// args.expressions.len()
// );
// _fetch_timeseries_recursive(self, args, max_retries)
// }
fn _fetch_expression_batch(

View File

@@ -1,4 +1,8 @@
use msyrs::download::jpmaqsdownload::{JPMaQSDownload, JPMaQSDownloadGetIndicatorArgs};
use polars::error::PolarsError;
use polars::export::chrono::NaiveDate;
use polars::prelude::*;
use polars::series::Series;
fn main() {
println!("Authentication to DataQuery API");
@@ -16,7 +20,7 @@ fn main() {
start.elapsed()
);
-let num_ticks = 50;
+let num_ticks = 2500;
let sel_tickers: Vec<String> = tickers
.iter()
.take(num_ticks)
@@ -40,14 +44,28 @@
);
start = std::time::Instant::now();
-for indicator in indicators {
-// df_deets.push((indicator.ticker.clone(), indicator.as_qdf().unwrap().size()));
-println!(
-"Ticker: {}, DataFrame: {}",
-indicator.ticker,
-indicator.as_qdf().unwrap()
-);
-}
// let mut qdf_list = Vec::new();
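// Vertically stack every per-ticker frame onto an initially empty
// DataFrame; all as_qdf() frames must share a schema for vstack to succeed.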
let mega_df = indicators
.iter()
.map(|indicator| indicator.as_qdf().unwrap())
.fold(DataFrame::new(vec![]).unwrap(), |acc, df| {
acc.vstack(&df).unwrap()
});
// for indicator in indicators {
// // df_deets.push((indicator.ticker.clone(), indicator.as_qdf().unwrap().size()));
// // println!(
// // "Ticker: {}, DataFrame",
// // indicator.ticker,
// // // indicator.as_qdf().unwrap()
// // );
// qdf_list.push(indicator.as_qdf().unwrap());
// }
// // vstack the DataFrames
// let qdf = DataFrame::new(qdf_list).unwrap();
// print mega_df size
println!("Mega DataFrame size: {:?}", mega_df.size());
println!(
"Converted indicators to DataFrames in {:?}",
start.elapsed()