mirror of
https://github.com/Magnus167/msyrs.git
synced 2025-08-20 13:00:01 +00:00
working!
This commit is contained in:
parent
ab30aa2380
commit
ee19862036
556
Cargo.lock
generated
556
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
63
Cargo.toml
63
Cargo.toml
@ -16,7 +16,8 @@ reqwest = { version = "0.12.9", features = ["blocking", "json"] }
|
|||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
serde_urlencoded = "0.7"
|
serde_urlencoded = "0.7"
|
||||||
serde = { version = "1.0.215", features = ["derive"] }
|
serde = { version = "1.0.215", features = ["derive"] }
|
||||||
polars = { version = "0.44.2", features = ["lazy"] }
|
# polars = { version = "0.44.2", features = ["lazy"] }
|
||||||
|
chrono = "0.4.38"
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
threadpool = "1.8.1"
|
threadpool = "1.8.1"
|
||||||
log = "0.4.22"
|
log = "0.4.22"
|
||||||
@ -24,3 +25,63 @@ crossbeam = "0.8"
|
|||||||
rayon = "1.5"
|
rayon = "1.5"
|
||||||
tokio = "1.41.1"
|
tokio = "1.41.1"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
|
polars = { version = "^0.44.0", features = [
|
||||||
|
"lazy",
|
||||||
|
"temporal",
|
||||||
|
"describe",
|
||||||
|
"json",
|
||||||
|
"parquet",
|
||||||
|
"dtype-datetime",
|
||||||
|
"strings",
|
||||||
|
"timezones",
|
||||||
|
"ndarray",
|
||||||
|
"concat_str",
|
||||||
|
|
||||||
|
# "serde-lazy",
|
||||||
|
# "parquet",
|
||||||
|
# "decompress",
|
||||||
|
# "zip",
|
||||||
|
# "gzip",
|
||||||
|
"dynamic_group_by",
|
||||||
|
"rows",
|
||||||
|
"cross_join",
|
||||||
|
"semi_anti_join",
|
||||||
|
"row_hash",
|
||||||
|
"diagonal_concat",
|
||||||
|
"dataframe_arithmetic",
|
||||||
|
"partition_by",
|
||||||
|
"is_in",
|
||||||
|
"zip_with",
|
||||||
|
"round_series",
|
||||||
|
"repeat_by",
|
||||||
|
"is_first_distinct",
|
||||||
|
"is_last_distinct",
|
||||||
|
"checked_arithmetic",
|
||||||
|
"dot_product",
|
||||||
|
"concat_str",
|
||||||
|
"reinterpret",
|
||||||
|
"take_opt_iter",
|
||||||
|
"mode",
|
||||||
|
"cum_agg",
|
||||||
|
"rolling_window",
|
||||||
|
"interpolate",
|
||||||
|
"rank",
|
||||||
|
"moment",
|
||||||
|
"ewma",
|
||||||
|
"abs",
|
||||||
|
"product",
|
||||||
|
"diff",
|
||||||
|
"pct_change",
|
||||||
|
"unique_counts",
|
||||||
|
"log",
|
||||||
|
"list_to_struct",
|
||||||
|
"list_count",
|
||||||
|
"list_eval",
|
||||||
|
"cumulative_eval",
|
||||||
|
"arg_where",
|
||||||
|
"search_sorted",
|
||||||
|
"offset_by",
|
||||||
|
"trigonometry",
|
||||||
|
"sign",
|
||||||
|
"propagate_nans",
|
||||||
|
] }
|
||||||
|
37
src/main.rs
37
src/main.rs
@ -1,6 +1,8 @@
|
|||||||
use msyrs::download::jpmaqsdownload::{JPMaQSDownload, JPMaQSDownloadGetIndicatorArgs};
|
use msyrs::download::jpmaqsdownload::{JPMaQSDownload, JPMaQSDownloadGetIndicatorArgs};
|
||||||
use msyrs::utils::dftools::is_quantamental_dataframe;
|
use msyrs::utils::dftools as msyrs_dftools;
|
||||||
fn main() {
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn download_stuff() {
|
||||||
println!("Authentication to DataQuery API");
|
println!("Authentication to DataQuery API");
|
||||||
let mut jpamqs_download = JPMaQSDownload::default();
|
let mut jpamqs_download = JPMaQSDownload::default();
|
||||||
|
|
||||||
@ -25,28 +27,6 @@ fn main() {
|
|||||||
// let mut df_deets = Vec::new();
|
// let mut df_deets = Vec::new();
|
||||||
|
|
||||||
println!("Retrieving indicators for {} tickers", sel_tickers.len());
|
println!("Retrieving indicators for {} tickers", sel_tickers.len());
|
||||||
// start = std::time::Instant::now();
|
|
||||||
// let all_metrics: Vec<String> = ["value", "grading", "eop_lag", "mop_lag"]
|
|
||||||
// .iter()
|
|
||||||
// .map(|x| x.to_string())
|
|
||||||
// .collect();
|
|
||||||
// let res = jpamqs_download.save_indicators_as_csv(
|
|
||||||
// JPMaQSDownloadGetIndicatorArgs {
|
|
||||||
// tickers: sel_tickers.clone(),
|
|
||||||
// metrics: all_metrics,
|
|
||||||
// ..Default::default()
|
|
||||||
// },
|
|
||||||
// "./data/",
|
|
||||||
// );
|
|
||||||
|
|
||||||
// match res {
|
|
||||||
// Ok(_) => println!(
|
|
||||||
// "Saved indicators for {} tickers in {:?}",
|
|
||||||
// sel_tickers.len(),
|
|
||||||
// start.elapsed()
|
|
||||||
// ),
|
|
||||||
// Err(e) => println!("Error saving indicators: {:?}", e),
|
|
||||||
// }
|
|
||||||
|
|
||||||
let res_df = jpamqs_download
|
let res_df = jpamqs_download
|
||||||
.get_indicators_qdf(JPMaQSDownloadGetIndicatorArgs {
|
.get_indicators_qdf(JPMaQSDownloadGetIndicatorArgs {
|
||||||
@ -62,9 +42,16 @@ fn main() {
|
|||||||
start.elapsed()
|
start.elapsed()
|
||||||
);
|
);
|
||||||
|
|
||||||
if !is_quantamental_dataframe(&res_df) {
|
if !msyrs_dftools::is_quantamental_dataframe(&res_df) {
|
||||||
println!("DataFrame is not a quantamental DataFrame");
|
println!("DataFrame is not a quantamental DataFrame");
|
||||||
} else {
|
} else {
|
||||||
println!("DataFrame is a quantamental DataFrame");
|
println!("DataFrame is a quantamental DataFrame");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// E:\Work\ruzt\msyrs\data\JPMaQSData\ALLIFCDSGDP\AUD_ALLIFCDSGDP_NSA.csv
|
||||||
|
let pth = "E:/Work/ruzt/msyrs/data/JPMaQSData/ALLIFCDSGDP/AUD_ALLIFCDSGDP_NSA.csv";
|
||||||
|
let df = msyrs_dftools::load_quantamental_dataframe(pth).unwrap();
|
||||||
|
println!("{:?}", df);
|
||||||
|
}
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
|
use crate::utils::misc::*;
|
||||||
|
use polars::datatypes::DataType;
|
||||||
use polars::prelude::*;
|
use polars::prelude::*;
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
/// The standard metrics provided by JPMaQS (`value`, `grading`, `eop_lag`, `mop_lag`).
|
/// The standard metrics provided by JPMaQS (`value`, `grading`, `eop_lag`, `mop_lag`).
|
||||||
pub const DEFAULT_JPMAQS_METRICS: [&str; 4] = ["value", "grading", "eop_lag", "mop_lag"];
|
pub const DEFAULT_JPMAQS_METRICS: [&str; 4] = ["value", "grading", "eop_lag", "mop_lag"];
|
||||||
@ -6,7 +9,6 @@ pub const DEFAULT_JPMAQS_METRICS: [&str; 4] = ["value", "grading", "eop_lag", "m
|
|||||||
/// The required columns for a Quantamental DataFrame.
|
/// The required columns for a Quantamental DataFrame.
|
||||||
pub const QDF_INDEX_COLUMNS: [&str; 3] = ["real_date", "cid", "xcat"];
|
pub const QDF_INDEX_COLUMNS: [&str; 3] = ["real_date", "cid", "xcat"];
|
||||||
|
|
||||||
|
|
||||||
/// Check if a DataFrame is a quantamental DataFrame.
|
/// Check if a DataFrame is a quantamental DataFrame.
|
||||||
/// A standard Quantamental DataFrame has the following columns:
|
/// A standard Quantamental DataFrame has the following columns:
|
||||||
/// - `real_date`: Date column as a date type
|
/// - `real_date`: Date column as a date type
|
||||||
@ -15,49 +17,302 @@ pub const QDF_INDEX_COLUMNS: [&str; 3] = ["real_date", "cid", "xcat"];
|
|||||||
///
|
///
|
||||||
/// Additionally, the DataFrame should have atleast 1 more column.
|
/// Additionally, the DataFrame should have atleast 1 more column.
|
||||||
/// Typically, this is one (or more) of the default JPMaQS metics.
|
/// Typically, this is one (or more) of the default JPMaQS metics.
|
||||||
pub fn is_quantamental_dataframe(df: &DataFrame) -> bool {
|
pub fn check_quantamental_dataframe(df: &DataFrame) -> Result<(), Box<dyn Error>> {
|
||||||
let columns = df
|
let expected_cols = ["real_date", "cid", "xcat"];
|
||||||
.get_column_names()
|
let expected_dtype = [DataType::Date, DataType::String, DataType::String];
|
||||||
.iter()
|
for (col, dtype) in expected_cols.iter().zip(expected_dtype.iter()) {
|
||||||
.map(|s| s.as_str())
|
let col = df.column(col);
|
||||||
.collect::<Vec<&str>>();
|
if col.is_err() {
|
||||||
let has_idx_columns = QDF_INDEX_COLUMNS.iter().all(|col| columns.contains(col));
|
return Err(format!("Column {:?} not found", col).into());
|
||||||
if !has_idx_columns {
|
}
|
||||||
return false;
|
let col = col?;
|
||||||
|
if col.dtype() != dtype {
|
||||||
|
return Err(format!("Column {:?} has wrong dtype", col).into());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let real_date_col = df.select(["real_date"]);
|
Ok(())
|
||||||
match real_date_col {
|
}
|
||||||
Ok(_) => {}
|
|
||||||
Err(_) => return false,
|
/// Check if a DataFrame is a quantamental DataFrame.
|
||||||
};
|
/// Returns true if the DataFrame is a quantamental DataFrame, false otherwise.
|
||||||
|
/// Uses the `check_quantamental_dataframe` function to check if the DataFrame is a quantamental DataFrame.
|
||||||
let is_date_dtype = real_date_col
|
pub fn is_quantamental_dataframe(df: &DataFrame) -> bool {
|
||||||
.unwrap()
|
check_quantamental_dataframe(df).is_ok()
|
||||||
.dtypes()
|
}
|
||||||
.iter()
|
|
||||||
.all(|dtype| dtype == &DataType::Date);
|
pub fn sort_qdf_columns(qdf: &mut DataFrame) -> Result<(), Box<dyn Error>> {
|
||||||
|
let index_columns = ["real_date", "cid", "xcat"];
|
||||||
if !is_date_dtype {
|
let known_metrics = ["value", "grading", "eop_lag", "mop_lag"];
|
||||||
return false;
|
|
||||||
}
|
let df_columns = qdf
|
||||||
|
.get_column_names()
|
||||||
let cid_col = df.select(["cid"]);
|
.into_iter()
|
||||||
match cid_col {
|
.map(|s| s.clone().into_string())
|
||||||
Ok(_) => {}
|
.collect::<Vec<String>>();
|
||||||
Err(_) => return false,
|
|
||||||
};
|
let mut unknown_metrics: Vec<String> = df_columns
|
||||||
|
.iter()
|
||||||
let xcat_col = df.select(["xcat"]);
|
.filter(|&m| !known_metrics.contains(&m.as_str()))
|
||||||
match xcat_col {
|
.filter(|&m| !index_columns.contains(&m.as_str()))
|
||||||
Ok(_) => {}
|
.cloned()
|
||||||
Err(_) => return false,
|
.collect();
|
||||||
};
|
|
||||||
|
let mut new_columns: Vec<String> = vec![];
|
||||||
// has atleast 1 more column
|
new_columns.extend(index_columns.iter().map(|s| s.to_string()));
|
||||||
let has_other_columns = columns.len() > 3;
|
for &colname in &known_metrics {
|
||||||
if !has_other_columns {
|
if df_columns.contains(&colname.into()) {
|
||||||
return false;
|
new_columns.push(colname.to_string());
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return true;
|
|
||||||
|
unknown_metrics.sort();
|
||||||
|
new_columns.extend(unknown_metrics);
|
||||||
|
*qdf = qdf
|
||||||
|
.select(new_columns.clone())
|
||||||
|
.expect("Failed to select columns");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_quantamental_dataframe(
|
||||||
|
file_path: &str,
|
||||||
|
) -> Result<DataFrame, Box<dyn std::error::Error>> {
|
||||||
|
// get the file base name
|
||||||
|
let file_name = std::path::Path::new(file_path)
|
||||||
|
.file_stem()
|
||||||
|
.unwrap()
|
||||||
|
.to_str()
|
||||||
|
.unwrap()
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
// if filename does not have _ then it is not a Quantamental DataFrame
|
||||||
|
if !file_name.contains('_') {
|
||||||
|
return Err("The file name must be in the format `cid_xcat.csv` (`ticker.csv`)".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let ticker = file_name.split('.').collect::<Vec<&str>>()[0];
|
||||||
|
let (cid, xcat) = split_ticker(ticker)?;
|
||||||
|
|
||||||
|
let mut df = CsvReadOptions::default()
|
||||||
|
.try_into_reader_with_file_path(Some(file_path.into()))
|
||||||
|
.unwrap()
|
||||||
|
.finish()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let err = "The dataframe must have a `real_date` column and atleast 1 additional value column";
|
||||||
|
if df.column("real_date").is_err() || df.width() < 2 {
|
||||||
|
return Err(err.into());
|
||||||
|
}
|
||||||
|
let real_date_col = df
|
||||||
|
.column("real_date".into())
|
||||||
|
.unwrap()
|
||||||
|
.cast(&DataType::Date)?;
|
||||||
|
|
||||||
|
df.with_column(real_date_col)?;
|
||||||
|
df.with_column(Series::new("cid".into(), vec![cid; df.height()]))?;
|
||||||
|
df.with_column(Series::new("xcat".into(), vec![xcat; df.height()]))?;
|
||||||
|
|
||||||
|
sort_qdf_columns(&mut df)?;
|
||||||
|
|
||||||
|
Ok(df)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get intersecting cross-sections from a DataFrame.
|
||||||
|
pub fn get_intersecting_cids(
|
||||||
|
df: &DataFrame,
|
||||||
|
xcats: &Option<Vec<String>>,
|
||||||
|
) -> Result<Vec<String>, Box<dyn Error>> {
|
||||||
|
let rel_xcats = xcats
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| get_unique_xcats(df).unwrap());
|
||||||
|
let found_tickers = get_unique_tickers(df)?;
|
||||||
|
let found_cids = get_unique_cids(df)?;
|
||||||
|
let keep_cids = get_intersecting_cids_str_func(&found_cids, &rel_xcats, &found_tickers);
|
||||||
|
Ok(keep_cids)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get intersecting tickers from a DataFrame.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn get_tickers_interesecting_on_xcat(
|
||||||
|
df: &DataFrame,
|
||||||
|
xcats: &Option<Vec<String>>,
|
||||||
|
) -> Result<Vec<String>, Box<dyn Error>> {
|
||||||
|
let rel_cids = get_intersecting_cids(df, xcats)?;
|
||||||
|
let rel_xcats = xcats
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| get_unique_xcats(df).unwrap());
|
||||||
|
let rel_cids_str: Vec<&str> = rel_cids.iter().map(AsRef::as_ref).collect();
|
||||||
|
let rel_xcats_str: Vec<&str> = rel_xcats.iter().map(AsRef::as_ref).collect();
|
||||||
|
Ok(create_interesecting_tickers(&rel_cids_str, &rel_xcats_str))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the unique tickers from a Quantamental DataFrame.
|
||||||
|
pub fn get_ticker_column_for_quantamental_dataframe(
|
||||||
|
df: &DataFrame,
|
||||||
|
) -> Result<Column, Box<dyn Error>> {
|
||||||
|
check_quantamental_dataframe(df)?;
|
||||||
|
let mut ticker_df =
|
||||||
|
DataFrame::new(vec![df.column("cid")?.clone(), df.column("xcat")?.clone()])?
|
||||||
|
.lazy()
|
||||||
|
.select([concat_str([col("cid"), col("xcat")], "_", true)])
|
||||||
|
.collect()?;
|
||||||
|
|
||||||
|
Ok(ticker_df
|
||||||
|
.rename("cid", "ticker".into())
|
||||||
|
.unwrap()
|
||||||
|
.column("ticker")
|
||||||
|
.unwrap()
|
||||||
|
.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the unique tickers from a DataFrame.
|
||||||
|
/// Returns a Vec of unique tickers.
|
||||||
|
pub fn get_unique_tickers(df: &DataFrame) -> Result<Vec<String>, Box<dyn Error>> {
|
||||||
|
let ticker_col = get_ticker_column_for_quantamental_dataframe(df)?;
|
||||||
|
_get_unique_strs_from_str_column_object(&ticker_col)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the unique cross-sectional identifiers (`cids`) from a DataFrame.
|
||||||
|
pub fn get_unique_cids(df: &DataFrame) -> Result<Vec<String>, Box<dyn Error>> {
|
||||||
|
check_quantamental_dataframe(df)?;
|
||||||
|
get_unique_from_str_column(df, "cid")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the unique extended categories (`xcats`) from a DataFrame.
|
||||||
|
pub fn get_unique_xcats(df: &DataFrame) -> Result<Vec<String>, Box<dyn Error>> {
|
||||||
|
check_quantamental_dataframe(df)?;
|
||||||
|
get_unique_from_str_column(df, "xcat")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter a dataframe based on the given parameters.
|
||||||
|
/// - `cids`: Filter by cross-sectional identifiers
|
||||||
|
/// - `xcats`: Filter by extended categories
|
||||||
|
/// - `metrics`: Filter by metrics
|
||||||
|
/// - `start`: Filter by start date
|
||||||
|
/// - `end`: Filter by end date
|
||||||
|
/// - `intersect`: If true, intersect only return `cids` that are present for all `xcats`.
|
||||||
|
/// Returns a new DataFrame with the filtered data, without modifying the original DataFrame.
|
||||||
|
/// If no filters are provided, the original DataFrame is returned.
|
||||||
|
pub fn reduce_dataframe(
|
||||||
|
df: &DataFrame,
|
||||||
|
cids: Option<Vec<String>>,
|
||||||
|
xcats: Option<Vec<String>>,
|
||||||
|
metrics: Option<Vec<String>>,
|
||||||
|
start: Option<&str>,
|
||||||
|
end: Option<&str>,
|
||||||
|
intersect: bool,
|
||||||
|
) -> Result<DataFrame, Box<dyn Error>> {
|
||||||
|
check_quantamental_dataframe(df)?;
|
||||||
|
|
||||||
|
let mut new_df: DataFrame = df.clone();
|
||||||
|
|
||||||
|
let ticker_col: Column = get_ticker_column_for_quantamental_dataframe(&new_df)?;
|
||||||
|
|
||||||
|
// if cids is not provided, get all unique cids
|
||||||
|
let u_cids: Vec<String> = get_unique_cids(&new_df)?;
|
||||||
|
let u_xcats: Vec<String> = get_unique_xcats(&new_df)?;
|
||||||
|
let u_tickers: Vec<String> = _get_unique_strs_from_str_column_object(&ticker_col)?;
|
||||||
|
|
||||||
|
let specified_cids: Vec<String> = cids.unwrap_or_else(|| u_cids.clone());
|
||||||
|
let specified_xcats: Vec<String> = xcats.unwrap_or_else(|| u_xcats.clone());
|
||||||
|
let specified_metrics: Vec<String> = metrics.unwrap_or_else(|| {
|
||||||
|
DEFAULT_JPMAQS_METRICS
|
||||||
|
.iter()
|
||||||
|
.map(|&s| s.to_string())
|
||||||
|
.collect()
|
||||||
|
});
|
||||||
|
let specified_tickers: Vec<String> = create_interesecting_tickers(
|
||||||
|
&specified_cids
|
||||||
|
.iter()
|
||||||
|
.map(AsRef::as_ref)
|
||||||
|
.collect::<Vec<&str>>(),
|
||||||
|
&specified_xcats
|
||||||
|
.iter()
|
||||||
|
.map(AsRef::as_ref)
|
||||||
|
.collect::<Vec<&str>>(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let keep_tickers: Vec<String> = match intersect {
|
||||||
|
true => get_intersecting_cids_str_func(&u_cids, &u_xcats, &u_tickers),
|
||||||
|
false => specified_tickers.clone(),
|
||||||
|
};
|
||||||
|
let kticks: Vec<&str> = keep_tickers
|
||||||
|
.iter()
|
||||||
|
.map(AsRef::as_ref)
|
||||||
|
.collect::<Vec<&str>>();
|
||||||
|
|
||||||
|
// Create a boolean mask to filter rows based on the tickers
|
||||||
|
let mut mask = vec![false; ticker_col.len()];
|
||||||
|
for (i, ticker) in ticker_col.str()?.iter().enumerate() {
|
||||||
|
if let Some(t) = ticker {
|
||||||
|
if kticks.contains(&t) {
|
||||||
|
mask[i] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mask = BooleanChunked::from_slice("mask".into(), &mask);
|
||||||
|
new_df = new_df.filter(&mask)?;
|
||||||
|
|
||||||
|
// Apply date filtering if `start` or `end` is provided
|
||||||
|
if let Some(start_date) = start {
|
||||||
|
new_df = new_df
|
||||||
|
.lazy()
|
||||||
|
.filter(col("real_date").gt_eq(start_date))
|
||||||
|
.collect()?;
|
||||||
|
}
|
||||||
|
if let Some(end_date) = end {
|
||||||
|
new_df = new_df
|
||||||
|
.lazy()
|
||||||
|
.filter(col("real_date").lt_eq(end_date))
|
||||||
|
.collect()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter based on metrics if provided
|
||||||
|
assert!(specified_metrics.len() > 0);
|
||||||
|
|
||||||
|
// remove columns that are not in the specified metrics
|
||||||
|
let mut cols_to_remove = Vec::new();
|
||||||
|
for col in new_df.get_column_names() {
|
||||||
|
if !specified_metrics.contains(&col.to_string()) {
|
||||||
|
cols_to_remove.push(col);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
new_df = new_df.drop_many(
|
||||||
|
cols_to_remove
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.collect::<Vec<String>>(),
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(new_df)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update a Quantamental DataFrame with new data.
|
||||||
|
/// - `df`: The original DataFrame
|
||||||
|
/// - `df_add`: The new DataFrame to add
|
||||||
|
///
|
||||||
|
pub fn update_dataframe(
|
||||||
|
df: &DataFrame,
|
||||||
|
df_add: &DataFrame,
|
||||||
|
// xcat_replace: Option<&str>,
|
||||||
|
) -> Result<DataFrame, Box<dyn Error>> {
|
||||||
|
check_quantamental_dataframe(df)?;
|
||||||
|
check_quantamental_dataframe(df_add)?;
|
||||||
|
if df.is_empty() {
|
||||||
|
return Ok(df_add.clone());
|
||||||
|
} else if df_add.is_empty() {
|
||||||
|
return Ok(df.clone());
|
||||||
|
};
|
||||||
|
|
||||||
|
// vstack and drop duplicates keeping last
|
||||||
|
let mut new_df = df.vstack(df_add)?;
|
||||||
|
// help?
|
||||||
|
let idx_cols_vec = QDF_INDEX_COLUMNS
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.collect::<Vec<String>>();
|
||||||
|
|
||||||
|
new_df = new_df.unique_stable(Some(&idx_cols_vec), UniqueKeepStrategy::Last, None)?;
|
||||||
|
|
||||||
|
Ok(new_df)
|
||||||
}
|
}
|
||||||
|
83
src/utils/misc.rs
Normal file
83
src/utils/misc.rs
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
use polars::prelude::*;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
|
pub fn split_ticker(ticker: &str) -> Result<(&str, &str), Box<dyn Error>> {
|
||||||
|
// split by the first underscore character. return the first and second parts.
|
||||||
|
let parts: Vec<&str> = ticker.splitn(2, '_').collect();
|
||||||
|
if parts.len() != 2 {
|
||||||
|
return Err("Invalid ticker format".into());
|
||||||
|
}
|
||||||
|
Ok((parts[0], parts[1]))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn get_cid(ticker: &str) -> Result<&str, Box<dyn Error>> {
|
||||||
|
split_ticker(ticker).map(|(cid, _)| cid)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn get_xcat(ticker: &str) -> Result<&str, Box<dyn Error>> {
|
||||||
|
split_ticker(ticker).map(|(_, xcat)| xcat)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_ticker(cid: &str, xcat: &str) -> String {
|
||||||
|
format!("{}_{}", cid, xcat)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_interesecting_tickers(cids: &[&str], xcats: &[&str]) -> Vec<String> {
|
||||||
|
let mut tickers = Vec::new();
|
||||||
|
for cid in cids {
|
||||||
|
for xcat in xcats {
|
||||||
|
tickers.push(create_ticker(cid, xcat));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tickers
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Backed function to get unique strings from a string column object.
|
||||||
|
pub fn _get_unique_strs_from_str_column_object(
|
||||||
|
col: &Column,
|
||||||
|
) -> Result<Vec<String>, Box<dyn Error>> {
|
||||||
|
let res = col
|
||||||
|
.unique()?
|
||||||
|
.sort(SortOptions::default())?
|
||||||
|
.drop_nulls()
|
||||||
|
.str()?
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.unwrap_or_default().to_string())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the unique values from a string column in a DataFrame.
|
||||||
|
pub fn get_unique_from_str_column(
|
||||||
|
df: &DataFrame,
|
||||||
|
col: &str,
|
||||||
|
) -> Result<Vec<String>, Box<dyn Error>> {
|
||||||
|
_get_unique_strs_from_str_column_object(&df.column(col).unwrap())
|
||||||
|
}
|
||||||
|
pub fn get_intersecting_cids_str_func(
|
||||||
|
cids: &Vec<String>,
|
||||||
|
xcats: &Vec<String>,
|
||||||
|
found_tickers: &Vec<String>,
|
||||||
|
) -> Vec<String> {
|
||||||
|
let mut keep_cids = cids.clone();
|
||||||
|
// make a hashmap of cids to xcats
|
||||||
|
let mut cid_xcat_map = HashMap::new();
|
||||||
|
for ticker in found_tickers {
|
||||||
|
let (cid, xcat) = split_ticker(&ticker).unwrap();
|
||||||
|
cid_xcat_map.insert(cid.to_string(), xcat.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// filter out cids that are not present in all xcats
|
||||||
|
for (cid, xcats_for_cid) in cid_xcat_map.iter() {
|
||||||
|
// if the all xcats are not present, remove the cid
|
||||||
|
if !xcats.iter().all(|xcat| xcats_for_cid.contains(xcat)) {
|
||||||
|
keep_cids.retain(|x| x != cid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
keep_cids
|
||||||
|
}
|
@ -1 +1,2 @@
|
|||||||
pub mod dftools;
|
pub mod dftools;
|
||||||
|
pub mod misc;
|
Loading…
x
Reference in New Issue
Block a user