This commit is contained in:
Palash Tyagi 2024-11-12 02:01:35 +00:00
parent 289cd9cc0b
commit 91aa5762d6
5 changed files with 133 additions and 186 deletions

View File

@ -1,7 +1,8 @@
use crate::download::oauth_client::OAuthClient;
use crate::download::requester::DQRequester;
use crate::download::timeseries::DQTimeseriesRequestArgs;
use crate::download::timeseries::DQTimeSeries;
use crate::download::timeseries::DQTimeSeriesResponse;
use crate::download::timeseries::DQTimeseriesRequestArgs;
use crate::download::timeseries::JPMaQSIndicator;
use rayon::prelude::*;
use std::error::Error;
@ -92,10 +93,13 @@ impl JPMaQSDownload {
&mut self,
expressions: Vec<String>,
) -> Result<Vec<DQTimeSeriesResponse>, Box<dyn Error>> {
let dqts_vec = self.requester.get_timeseries(DQTimeseriesRequestArgs {
let dqts_vec = self
.requester
.get_timeseries(DQTimeseriesRequestArgs {
expressions: expressions,
..Default::default()
}).unwrap();
})
.unwrap();
Ok(dqts_vec)
}
@ -113,35 +117,50 @@ impl JPMaQSDownload {
let dqts_vec = self.get_expressions(expressions)?;
// println!("Retrieved {} time series", -- sum[dqts_vec.iter().map(|dqts| dqts.len())]);
println!(
"Retrieved {} time series",
dqts_vec
.iter()
.map(|dqts| dqts.list_expressions().len())
.sum::<usize>()
);
println!("Retrieved all time series",);
// say pausing for 30 seconds
println!("Pausing for 10 seconds");
std::thread::sleep(std::time::Duration::from_secs(10));
println!("Resuming");
let start = std::time::Instant::now();
// let indicators = dqts_vec
// .iter()
// .flat_map(|dqts| dqts.get_timeseries_by_ticker())
// .map(|tsv| JPMaQSIndicator::new(tsv))
// .collect::<Result<Vec<JPMaQSIndicator>, Box<dyn Error>>>()?;
let indicators: Vec<_> = dqts_vec
.into_par_iter() // Use into_par_iter() to parallelize and consume dqts_vec
.flat_map(|dqts| {
dqts.get_timeseries_by_ticker()
.into_par_iter()
.filter_map(|tsv| JPMaQSIndicator::new(tsv).ok())
})
.collect();
let indicators = dq_response_vec_to_jpmaqs_indicators(dqts_vec);
println!(
"Converted time series to indicators in {:?}",
start.elapsed()
);
println!("Pausing for 10 seconds");
std::thread::sleep(std::time::Duration::from_secs(10));
println!("Resuming");
Ok(indicators)
}
}
// fn dq_response_vec_to_jpmaqs_indicators(
// dqts_vec: Vec<DQTimeSeriesResponse>,
// ) -> Vec<JPMaQSIndicator> {
// let mut indicators: Vec<JPMaQSIndicator> = Vec::new();
// for dqts in dqts_vec {
// indicators.extend(
// dqts.consume_to_grouped_by_ticker() // moves the values to free up memory
// .into_iter()
// .filter_map(|tsv| JPMaQSIndicator::new(tsv).ok()),
// );
// }
// indicators
// }
/// Convert a batch of DataQuery responses into JPMaQS indicators.
///
/// Every response is consumed: its time series are grouped by ticker with
/// `consume_to_grouped_by_ticker`, and each group is passed to
/// `JPMaQSIndicator::new`. Groups for which `JPMaQSIndicator::new` returns an
/// error are dropped without being reported (best-effort conversion).
///
/// The responses are distributed across rayon's thread pool; the groups that
/// belong to one response are handled sequentially on whichever worker picked
/// that response up.
fn dq_response_vec_to_jpmaqs_indicators(
    dqts_vec: Vec<DQTimeSeriesResponse>,
) -> Vec<JPMaQSIndicator> {
    dqts_vec
        .into_par_iter()
        // The inner iterator is an ordinary sequential one, so `flat_map_iter`
        // is used instead of `flat_map`: there is no need to collect each
        // response's indicators into a temporary `Vec` first.
        .flat_map_iter(|dqts| {
            dqts.consume_to_grouped_by_ticker()
                .into_iter()
                .filter_map(|tsv| JPMaQSIndicator::new(tsv).ok())
        })
        .collect()
}

View File

@ -79,7 +79,7 @@ impl ParallelRequester {
runtime.block_on(async {
let mut tasks = vec![];
let mut curr_batch = 0;
let mut curr_batch = 1;
let total_batches = expression_batches.len();
for batch in expression_batches.into_iter() {
let self_arc = Arc::clone(&self_arc);
@ -89,7 +89,10 @@ impl ParallelRequester {
args_clone.update_expressions(batch.clone());
let ep = format!("{}?{}", TIMESERIES_ENDPOINT, args_clone.as_query_string());
let permit = semaphore.clone().acquire_owned().await.unwrap();
println!("Requesting batch {} of {}", curr_batch + 1, total_batches);
// println!("Requesting batch {} of {}", curr_batch + 1, total_batches);
if curr_batch % 10 == 0 {
println!("Requesting batch {} of {}", curr_batch, total_batches);
}
let task = task::spawn(async move {
let _permit = permit; // Keep the permit alive until the end of the task
let res_str = self_arc._request_async(&ep).await;

View File

@ -4,21 +4,12 @@ use crate::download::timeseries::DQCatalogueResponse;
use crate::download::timeseries::DQCatalogueSingleResponse;
use crate::download::timeseries::DQTimeSeriesResponse;
use crate::download::timeseries::DQTimeseriesRequestArgs;
use crossbeam::channel;
use rayon::iter::IntoParallelRefMutIterator;
use rayon::iter::ParallelDrainRange;
use rayon::iter::{ParallelIterator, IntoParallelIterator};
use rayon::prelude::*;
// use futures::TryFutureExt;
use crate::download::timeseries::JPMaQSIndicator;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use reqwest;
use reqwest::blocking::Client;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
// use std::collections::HashMap;
use std::error::Error;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
const API_BASE_URL: &str = "https://api-developer.jpmorgan.com/research/dataquery-authe/api/v2";
const HEARTBEAT_ENDPOINT: &str = "/services/heartbeat";
@ -26,7 +17,7 @@ const TIMESERIES_ENDPOINT: &str = "/expressions/time-series";
const CATALOGUE_ENDPOINT: &str = "/group/instruments";
const API_DELAY_MILLIS: u64 = 200;
const MAX_THREAD_WORKERS: usize = 100;
// const MAX_THREAD_WORKERS: usize = 100;
// const JPMAQS_CATALOGUE_GROUP: &str = "JPMAQS";
#[derive(Debug, Clone)]
@ -161,159 +152,42 @@ impl DQRequester {
args.expressions.len()
);
let mut pq = ParallelRequester::new(self.oauth_client.clone());
let pq_output = pq.request_expressions(args, max_retries);
let start = std::time::Instant::now();
let response_texts = match pq_output {
let response_texts = match pq.request_expressions(args, max_retries) {
Ok(r) => r,
Err(e) => return Err(e),
};
println!(
"Time elapsed for pq.request_expressions: {:?}",
start.elapsed()
);
// Parse the responses using rayon's parallel iterator
let dqts_vec: Vec<DQTimeSeriesResponse> = response_texts
.par_iter()
.filter_map(|rtext| serde_json::from_str::<DQTimeSeriesResponse>(rtext).ok())
.collect();
let dqts_vec: Vec<DQTimeSeriesResponse> = parse_response_texts(response_texts);
// Sleep for 10 seconds
println!("Pausing for 10 seconds");
std::thread::sleep(std::time::Duration::from_secs(10));
println!("Resuming");
Ok(dqts_vec)
}
}
// pub fn get_timeseries(
// &mut self,
// args: DQTimeseriesRequestArgs,
// ) -> Result<Vec<DQTimeSeriesResponse>, Box<dyn Error>> {
// let max_retries = 5;
// println!(
// "Invoking recursive function for {:?} expressions",
// args.expressions.len()
// );
// _fetch_timeseries_recursive(self, args, max_retries)
// }
// }
fn _fetch_expression_batch(
client: Arc<Mutex<DQRequester>>,
expr_batch: Vec<String>,
okay_responses: Arc<Mutex<Vec<DQTimeSeriesResponse>>>,
failed_batches: Arc<Mutex<Vec<Vec<String>>>>,
args: DQTimeseriesRequestArgs,
) {
let response = client.lock().unwrap()._fetch_single_timeseries_batch(args);
match response {
Ok(r) => {
// Attempt to parse the response text
match serde_json::from_str::<DQTimeSeriesResponse>(&r.text().unwrap()) {
Ok(dq_response) => {
okay_responses.lock().unwrap().push(dq_response);
log::info!("Got batch: {:?}", expr_batch);
}
Err(e) => {
// If parsing fails, treat this as a batch failure
failed_batches.lock().unwrap().push(expr_batch.clone());
log::error!("Failed to parse timeseries: {:?} : {:?}", expr_batch, e);
fn parse_response_texts(response_texts: Vec<String>) -> Vec<DQTimeSeriesResponse> {
response_texts
.into_par_iter()
.filter_map(|rtext| {
// Attempt to deserialize and immediately release rtext if successful
match serde_json::from_str::<DQTimeSeriesResponse>(&rtext) {
Ok(dqts) => Some(dqts),
Err(err) => {
eprintln!("Failed to deserialize response: {}", err);
None
}
}
}
Err(e) => {
// Handle _fetch_single_timeseries_batch error
failed_batches.lock().unwrap().push(expr_batch.clone());
log::error!("Failed to get batch: {:?} : {:?}", expr_batch, e);
}
}
}
fn _fetch_timeseries_recursive(
dq_requester: &mut DQRequester,
args: DQTimeseriesRequestArgs,
max_retries: u32,
) -> Result<Vec<DQTimeSeriesResponse>, Box<dyn Error>> {
// sort to ensure that the order of expressions is consistent
let expression_batches: Vec<Vec<String>> = args
.expressions
.chunks(20)
.map(|chunk| {
let mut vec = chunk.to_vec();
vec.sort();
vec
})
.collect();
let okay_responses = Arc::new(Mutex::new(Vec::new()));
let failed_batches = Arc::new(Mutex::new(Vec::new()));
let client = Arc::new(Mutex::new(dq_requester.clone()));
let (sender, receiver) = channel::unbounded();
let total_batches = expression_batches.len();
let mut curr_batch = 0;
// Spawn 20 worker threads
let mut workers = vec![];
for _ in 0..MAX_THREAD_WORKERS {
let receiver = receiver.clone();
let okay_responses = Arc::clone(&okay_responses);
let failed_batches = Arc::clone(&failed_batches);
let client = Arc::clone(&client);
let worker = thread::spawn(move || {
while let Ok((args, expr_batch)) = receiver.recv() {
_fetch_expression_batch(
client.clone(),
expr_batch,
okay_responses.clone(),
failed_batches.clone(),
args,
);
.collect()
}
});
workers.push(worker);
}
// Send jobs to workers
for expr_batch in expression_batches {
curr_batch += 1;
let mut args = args.clone();
args.update_expressions(expr_batch.clone());
log::info!("Processed {} out of {} batches", curr_batch, total_batches);
thread::sleep(Duration::from_millis(API_DELAY_MILLIS));
sender.send((args, expr_batch)).unwrap();
}
drop(sender); // Close the channel so workers can finish
// Wait for all workers to finish
for worker in workers {
worker.join().unwrap();
}
let mut okay_responses = Arc::try_unwrap(okay_responses)
.unwrap()
.into_inner()
.unwrap();
let failed_batches = Arc::try_unwrap(failed_batches)
.unwrap()
.into_inner()
.unwrap();
if !failed_batches.is_empty() && max_retries == 0 {
return Err("Max retries reached".into());
}
if !failed_batches.is_empty() && max_retries > 0 {
log::info!("Retrying failed batches");
let mut new_args = args.clone();
new_args.update_expressions(failed_batches.concat());
log::info!("Retrying with {} failed expressions", failed_batches.len());
let mut retry_responses =
_fetch_timeseries_recursive(dq_requester, new_args, max_retries - 1)?;
okay_responses.append(&mut retry_responses);
}
log::info!("Returning {} responses", okay_responses.len());
Ok(okay_responses)
}
#[allow(dead_code)]
fn main() {
let client_id = std::env::var("DQ_CLIENT_ID").unwrap();

View File

@ -8,9 +8,6 @@ use std::collections::HashSet;
use std::error::Error;
use std::fs::File;
#[derive(Debug, Clone)]
pub struct DQTimeseriesRequestArgs {
pub start_date: String,
@ -79,8 +76,6 @@ impl Default for DQTimeseriesRequestArgs {
}
}
/// Response from the DataQuery API
#[derive(Deserialize, Debug)]
pub struct DQTimeSeriesResponse {
@ -144,6 +139,37 @@ struct Attribute {
time_series: Vec<(String, Option<f64>)>,
}
impl Attribute {
    /// Return the ticker, i.e. the second comma-separated field of the
    /// expression (the text between the `DB(JPMAQS,` prefix and the next
    /// comma, or the end of the expression if there is no further comma).
    ///
    /// # Errors
    ///
    /// Returns an error when the expression does not start with
    /// `DB(JPMAQS,`, or when the extracted ticker is an empty string.
    pub fn get_ticker(&self) -> Result<String, Box<dyn Error>> {
        let fields = self
            .expression
            .strip_prefix("DB(JPMAQS,")
            .ok_or("Expression does not start with 'DB(JPMAQS,'")?;

        // The prefix ends with the first comma, so the first field of what
        // remains is the second field of the full expression.
        match fields.split(',').next() {
            Some(ticker) if !ticker.is_empty() => Ok(ticker.to_owned()),
            _ => Err("Ticker is empty".into()),
        }
    }

    /// Return the metric, i.e. the last comma-separated field of the
    /// expression after all trailing `)` characters have been removed.
    ///
    /// # Errors
    ///
    /// Returns an error when the expression does not start with
    /// `DB(JPMAQS,`, or when the extracted metric is an empty string.
    pub fn get_metric(&self) -> Result<String, Box<dyn Error>> {
        let expr = self.expression.as_str();
        if !expr.starts_with("DB(JPMAQS,") {
            return Err("Expression does not start with 'DB(JPMAQS,'".into());
        }

        // Walk from the right: the first piece yielded is the last field.
        let metric = expr.trim_end_matches(')').rsplit(',').next().unwrap_or("");
        if metric.is_empty() {
            Err("Metric is empty".into())
        } else {
            Ok(metric.to_owned())
        }
    }
}
/// Representation of a single time series
#[derive(Debug)]
pub struct DQTimeSeries {
@ -253,6 +279,31 @@ impl DQTimeSeriesResponse {
timeseries_by_ticker.into_iter().map(|(_, v)| v).collect()
}
/// Consume the response and return its time series grouped by ticker.
///
/// The method takes `self` by value, so it can be called at most once per
/// response. Each attribute of each instrument is turned into one
/// `DQTimeSeries` by moving its `expression` and `time_series` out, rather
/// than cloning them.
///
/// Attributes for which `get_ticker` returns an error are all grouped under
/// the empty-string ticker. The groups come back in `HashMap` iteration
/// order, which is unspecified.
pub fn consume_to_grouped_by_ticker(mut self) -> Vec<Vec<DQTimeSeries>> {
    let mut groups: HashMap<String, Vec<DQTimeSeries>> = HashMap::new();

    // Move the instruments out of `self`, leaving an empty vector behind.
    std::mem::take(&mut self.instruments)
        .into_iter()
        .flat_map(|instrument| instrument.attributes)
        .for_each(|attr| {
            // The ticker must be read before the fields are moved out below.
            let key = attr.get_ticker().unwrap_or_default();
            let series = DQTimeSeries {
                expression: attr.expression,
                time_series: attr.time_series,
            };
            groups.entry(key).or_default().push(series);
        });

    // Only the grouped values are needed; the ticker keys are discarded.
    groups.into_values().collect()
}
}
impl JPMaQSIndicator {

View File

@ -20,7 +20,7 @@ fn main() {
start.elapsed()
);
let num_ticks = 2500;
let num_ticks = 1000;
let sel_tickers: Vec<String> = tickers
.iter()
.take(num_ticks)