mirror of https://github.com/Magnus167/msyrs.git (synced 2025-08-20 18:40:00 +00:00)
use crate::utils::misc::*;
use crate::utils::qdf::pivots::*;
use crate::utils::qdf::reduce_df::*;
use chrono::NaiveDate;
use ndarray::{s, Array, Array1, Zip};
use polars::prelude::*;
use polars::series::Series; // Series struct

// use polars::time::Duration;

/// Returns the annualization factor for 252 trading days (sqrt(252)).
#[allow(dead_code)]
fn annualization_factor() -> f64 {
    252f64.sqrt()
}

/// Calculates exponential series weights for a finite horizon, normalized to sum to 1.
#[allow(dead_code)]
fn expo_weights(lback_periods: usize, half_life: f64) -> Array1<f64> {
    let decf = 2f64.powf(-1.0 / half_life);
    let mut weights = Array::from_iter(
        (0..lback_periods)
            .map(|ii| (lback_periods - ii - 1) as f64)
            .map(|exponent| decf.powf(exponent) * (1.0 - decf)),
    );
    weights /= weights.sum();
    weights
}
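
// A minimal sketch of the properties `expo_weights` is expected to have, using
// illustrative values (a lookback of 11 periods and a half-life of 5).
#[cfg(test)]
mod expo_weights_sketch {
    use super::*;

    #[test]
    fn weights_sum_to_one_and_halve_at_the_half_life() {
        let lback_periods = 11;
        let half_life = 5.0;
        let w = expo_weights(lback_periods, half_life);

        // Normalization: the weights sum to 1.
        assert!((w.sum() - 1.0).abs() < 1e-12);

        // The most recent observation (last element) carries the largest weight,
        // and the weight `half_life` periods earlier is half of it.
        let last = w[lback_periods - 1];
        let at_half_life = w[lback_periods - 1 - half_life as usize];
        assert!(w.iter().all(|&wi| wi <= last));
        assert!((at_half_life - 0.5 * last).abs() < 1e-12);
    }
}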

/// Exponentially weighted dispersion proxy: the weighted mean of absolute values
/// of `x` under weights `w`, with weights re-normalized after optionally dropping
/// zero entries.
#[allow(dead_code)]
fn expo_std(x: &Array1<f64>, w: &Array1<f64>, remove_zeros: bool) -> f64 {
    assert_eq!(x.len(), w.len(), "weights and window must have same length");
    let (filtered_x, filtered_w) = if remove_zeros {
        let indices: Vec<usize> = x
            .iter()
            .enumerate()
            .filter_map(|(i, &val)| if val != 0.0 { Some(i) } else { None })
            .collect();
        (
            Array::from_iter(indices.iter().map(|&i| x[i])),
            Array::from_iter(indices.iter().map(|&i| w[i])),
        )
    } else {
        (x.clone(), w.clone())
    };
    // Re-normalize the weights so they sum to 1 over the retained observations.
    let filtered_w = &filtered_w / filtered_w.sum();
    Zip::from(&filtered_x)
        .and(&filtered_w)
        .fold(0.0, |acc, &x_val, &w_val| acc + w_val * x_val.abs())
}

/// Flat-weighted counterpart of `expo_std`: the simple mean of absolute values of `x`,
/// optionally dropping zero entries first.
#[allow(dead_code)]
fn flat_std(x: &Array1<f64>, remove_zeros: bool) -> f64 {
    let filtered_x = if remove_zeros {
        x.iter()
            .filter(|&&val| val != 0.0)
            .cloned()
            .collect::<Array1<f64>>()
    } else {
        x.clone()
    };
    filtered_x.mapv(f64::abs).mean().unwrap_or(0.0)
}
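
// A minimal sketch, on made-up return values, of how the two dispersion proxies
// above behave: both are means of absolute values, but `expo_std` tilts towards
// the most recent observations.
#[cfg(test)]
mod std_proxy_sketch {
    use super::*;
    use ndarray::array;

    #[test]
    fn flat_and_exponential_proxies_on_a_toy_window() {
        let window = array![0.5, -1.0, 0.0, 2.0];

        // Flat proxy: mean of |x| over all four observations ...
        assert!((flat_std(&window, false) - 0.875).abs() < 1e-12);
        // ... or over the three non-zero ones when `remove_zeros` is set.
        assert!((flat_std(&window, true) - 3.5 / 3.0).abs() < 1e-12);

        // With a short half-life the exponential proxy weights the final +2.0 most
        // heavily, so it sits above the flat estimate.
        let weights = expo_weights(window.len(), 2.0);
        assert!(expo_std(&window, &weights, false) > flat_std(&window, false));
    }
}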

/// Rolling estimates at daily estimation frequency: one estimate per row, starting at
/// row `lback_periods - 1`, using flat ("ma") or exponential ("xma") weights over the
/// lookback window. Note that, unlike `freq_period_calc`, no annualization factor is
/// applied here. `nan_tolerance` is currently unused.
#[allow(unused_variables)]
fn freq_daily_calc(
    dfw: &DataFrame,
    lback_periods: usize,
    lback_method: &str,
    half_life: Option<f64>,
    remove_zeros: bool,
    nan_tolerance: f64,
) -> Result<DataFrame, Box<dyn std::error::Error>> {
    if lback_method == "xma" {
        assert!(
            half_life.is_some(),
            "If lback_method is 'xma', half_life must be provided."
        );
    }

    // Keep only the dates for which a full lookback window is available.
    let idx = UInt32Chunked::from_vec(
        "idx".into(),
        (lback_periods - 1..dfw.height())
            .map(|x| x as u32)
            .collect(),
    );
    let real_date_col = dfw.column("real_date")?.take(&idx)?;
    let mut new_df = DataFrame::new(vec![real_date_col])?;

    for col_name in dfw.get_column_names() {
        if col_name == "real_date" {
            continue;
        }
        let series = dfw.column(col_name)?;
        let values: Array1<f64> = series
            .f64()?
            .into_iter()
            .map(|opt| opt.unwrap_or(0.0))
            .collect();

        let result_series = match lback_method {
            "ma" => {
                let mut result = Vec::new();
                for i in (lback_periods - 1)..(values.len()) {
                    let window = values.slice(s![i + 1 - lback_periods..=i]);
                    let std = flat_std(&window.to_owned(), remove_zeros);
                    result.push(std);
                }
                Series::new(col_name.clone(), result)
            }
            "xma" => {
                let half_life = half_life.unwrap();
                let weights = expo_weights(lback_periods, half_life);
                let mut result = Vec::new();
                for i in (lback_periods - 1)..(values.len()) {
                    let window = values.slice(s![i + 1 - lback_periods..=i]);
                    let std = expo_std(&window.to_owned(), &weights, remove_zeros);
                    result.push(std);
                }
                Series::new(col_name.clone(), result)
            }
            _ => return Err("Invalid lookback method.".into()),
        };

        new_df.with_column(result_series)?;
    }

    Ok(new_df)
}
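
// A minimal sketch of the rolling "ma" calculation on a made-up wide frame. For
// simplicity the `real_date` column is plain strings here; `freq_daily_calc` only
// carries that column through, so its dtype does not matter for this check.
#[cfg(test)]
mod freq_daily_calc_sketch {
    use super::*;

    #[test]
    fn rolling_ma_drops_rows_without_a_full_window() -> Result<(), Box<dyn std::error::Error>> {
        let dfw = df!(
            "real_date" => &["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04", "2024-01-05"],
            "USD_XR" => &[0.1f64, -0.2, 0.3, 0.0, 0.5],
        )?;

        let out = freq_daily_calc(&dfw, 3, "ma", None, false, 0.25)?;

        // With a 3-period lookback the first two rows cannot be computed, so the
        // result has height() - lback_periods + 1 = 3 rows.
        assert_eq!(out.height(), 3);

        // The first estimate is the mean absolute value of the first window.
        let first = out.column("USD_XR")?.f64()?.get(0).unwrap();
        assert!((first - (0.1f64 + 0.2 + 0.3) / 3.0).abs() < 1e-12);
        Ok(())
    }
}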

/// Rolling estimates re-estimated only at the period boundaries implied by `est_freq`
/// (e.g. week, month, or quarter ends): one estimate per index returned by
/// `get_period_indices`, annualized by sqrt(252). Rows without a full lookback window
/// are filled with NaN. `nan_tolerance` is currently unused.
#[allow(unused_variables)]
fn freq_period_calc(
    dfw: &DataFrame,
    lback_periods: usize,
    lback_method: &str,
    half_life: Option<f64>,
    remove_zeros: bool,
    nan_tolerance: f64,
    est_freq: &str,
) -> Result<DataFrame, Box<dyn std::error::Error>> {
    if lback_method == "xma" {
        assert!(
            half_life.is_some(),
            "If lback_method is 'xma', half_life must be provided."
        );
    }

    println!("Calculating historic volatility with the following parameters:");
    println!(
        "lback_periods: {:?}, lback_method: {:?}, half_life: {:?}, remove_zeros: {:?}, nan_tolerance: {:?}, period: {:?}",
        lback_periods, lback_method, half_life, remove_zeros, nan_tolerance, est_freq
    );

    let period_indices: Vec<usize> = get_period_indices(dfw, est_freq)?;

    // Equivalent to the pandas expression: new_df = dfw['real_date'].iloc[period_indices].copy()
    let idx = UInt32Chunked::from_vec(
        "idx".into(),
        period_indices.iter().map(|&x| x as u32).collect(),
    );
    let real_date_col = dfw.column("real_date")?.take(&idx)?;
    let mut new_df = DataFrame::new(vec![real_date_col])?;

    for col_name in dfw.get_column_names() {
        if col_name == "real_date" {
            continue;
        }
        let series = dfw.column(col_name)?;
        let values: Array1<f64> = series
            .f64()?
            .into_iter()
            .map(|opt| opt.unwrap_or(0.0))
            .collect();

        let result_series = match lback_method {
            "ma" => {
                let mut result = Vec::new();
                for &i in &period_indices {
                    if i >= lback_periods - 1 {
                        let window = values.slice(s![i + 1 - lback_periods..=i]);
                        let std = flat_std(&window.to_owned(), remove_zeros);
                        let std = std * annualization_factor();
                        result.push(std);
                    } else {
                        result.push(f64::NAN);
                    }
                }
                Series::new(col_name.clone(), result)
            }
            "xma" => {
                let half_life = half_life.unwrap();
                let weights = expo_weights(lback_periods, half_life);
                let mut result = Vec::new();
                for &i in &period_indices {
                    if i >= lback_periods - 1 {
                        let window = values.slice(s![i + 1 - lback_periods..=i]);
                        let std = expo_std(&window.to_owned(), &weights, remove_zeros);
                        let std = std * annualization_factor();
                        result.push(std);
                    } else {
                        result.push(f64::NAN);
                    }
                }
                Series::new(col_name.clone(), result)
            }
            _ => return Err("Invalid lookback method.".into()),
        };
        println!(
            "Successfully calculated result_series for column: {:?}",
            col_name
        );
        new_df.with_column(result_series)?;
    }

    Ok(new_df)
}

/// Public wrapper around `get_bdates_from_col` that reads the `real_date` column
/// from a wide DataFrame.
pub fn get_bdates_from_col_hv(
    dfw: &DataFrame,
    est_freq: &str,
) -> Result<Series, Box<dyn std::error::Error>> {
    let date_series = dfw.column("real_date")?.as_series().unwrap();
    Ok(get_bdates_from_col(date_series, est_freq)?)
}

/// Public wrapper around the private `get_period_indices` helper.
pub fn get_period_indices_hv(
    dfw: &DataFrame,
    est_freq: &str,
) -> Result<Vec<usize>, Box<dyn std::error::Error>> {
    get_period_indices(dfw, est_freq)
}

/// Maps the business dates implied by `est_freq` back to their row positions in the
/// `real_date` column of the wide DataFrame.
fn get_period_indices(
    dfw: &DataFrame,
    est_freq: &str,
) -> Result<Vec<usize>, Box<dyn std::error::Error>> {
    let date_series = dfw.column("real_date")?.as_series().unwrap();
    let mut indices = Vec::new();

    let bdates: Series = get_bdates_from_col(date_series, est_freq)?;

    // Linear scan of the date column for each business date.
    for bdate in bdates.iter() {
        if let Some(index) = date_series.iter().position(|date| date == bdate) {
            indices.push(index);
        }
    }

    Ok(indices)
}

/// Calculate historic volatility.
///
/// # Arguments
/// - `df`: A Quantamental DataFrame.
/// - `xcat`: The category to calculate the historic volatility for.
/// - `cids`: A list of cross-sections. If none are provided, all cross-sections available
///     in the DataFrame will be used.
/// - `lback_periods`: The number of lookback periods to use for the calculation.
///     Defaults to 20.
/// - `lback_method`: The method used for the lookback calculation. Options are
///     'ma' (moving average) and 'xma' (exponentially weighted moving average).
///     Defaults to 'ma'.
/// - `half_life`: The half-life of the exponential weighting function (required when
///     `lback_method` is 'xma').
/// - `start`: Only include data after this date. Defaults to the earliest date available.
/// - `end`: Only include data before this date. Defaults to the latest date available.
/// - `est_freq`: The frequency at which volatility is re-estimated. Defaults to 'D' (daily).
///     Options are 'D' (daily), 'W' (weekly), 'M' (monthly), and 'Q' (quarterly).
/// - `remove_zeros`: Whether to remove zero values from the calculation. Defaults to false.
/// - `postfix`: A string appended to the XCAT of the result series. Defaults to '_HISTVOL'.
/// - `nan_tolerance`: The maximum proportion of NaN values allowed in the calculation.
///     Defaults to 0.25.
///
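/// # Example
///
/// A minimal usage sketch (not compiled as a doctest). The category name, parameter
/// choices, and postfix below are illustrative only; `df` is assumed to be a
/// Quantamental DataFrame with `cid`, `xcat`, `real_date`, and `value` columns.
///
/// ```ignore
/// let vol = historic_vol(
///     df,
///     "XR".to_string(),            // hypothetical category of daily returns
///     None,                        // all available cross-sections
///     Some(21),                    // 21-period lookback
///     Some("xma".to_string()),     // exponentially weighted moving average
///     Some(11.0),                  // half-life of 11 periods
///     None,                        // start: earliest available
///     None,                        // end: latest available
///     Some("M".to_string()),       // re-estimate at monthly frequency
///     Some(false),                 // keep zero returns
///     Some("ASD".to_string()),     // hypothetical postfix for the output category
///     None,                        // default NaN tolerance
/// )?;
/// ```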
#[allow(unused_variables)]
pub fn historic_vol(
    df: polars::prelude::DataFrame,
    xcat: String,
    cids: Option<Vec<String>>,
    lback_periods: Option<usize>,
    lback_method: Option<String>,
    half_life: Option<f64>,
    start: Option<String>,
    end: Option<String>,
    est_freq: Option<String>,
    remove_zeros: Option<bool>,
    postfix: Option<String>,
    nan_tolerance: Option<f64>,
) -> Result<DataFrame, Box<dyn std::error::Error>> {
    println!("Calculating historic volatility with the following parameters:");
    println!(
        "xcat: {:?},\ncids: {:?},\nlback_periods: {:?},\nlback_method: {:?},\nhalf_life: {:?},\nstart: {:?},\nend: {:?},\nest_freq: {:?},\nremove_zeros: {:?},\npostfix: {:?},\nnan_tolerance: {:?}",
        xcat, cids, lback_periods, lback_method, half_life, start, end, est_freq, remove_zeros, postfix, nan_tolerance
    );

    // Reduce to the requested cross-sections and category ...
    let rdf = reduce_dataframe(
        df.clone(),
        cids,
        Some(vec![xcat]),
        None,
        start.clone(),
        end.clone(),
        false,
    )?;

    // ... and pivot to a wide DataFrame with one column per ticker.
    let mut dfw = pivot_dataframe_by_ticker(rdf, Some("value".to_string()))?;

    println!("Successfully pivoted the DataFrame.");

    // Resolve defaults for the optional parameters.
    let lback_periods = lback_periods.unwrap_or(20);
    let lback_method = lback_method.unwrap_or("ma".to_string());

    println!("Successfully got lback_periods, lback_method, and half_life.");

    let start = start.unwrap_or(dfw.column("real_date")?.date()?.min().unwrap().to_string());
    let end = end.unwrap_or(dfw.column("real_date")?.date()?.max().unwrap().to_string());

    println!("Successfully got start and end dates.");

    let est_freq = est_freq.unwrap_or("D".to_string());

    println!("Successfully got est_freq.");

    let remove_zeros = remove_zeros.unwrap_or(false);

    println!("Successfully got remove_zeros.");

    let postfix = postfix.unwrap_or("_HISTVOL".to_string());

    println!("Successfully got postfix.");

    let nan_tolerance = nan_tolerance.unwrap_or(0.25);

    println!("Successfully got nan_tolerance.");

    let (dfw_start_date, dfw_end_date) =
        crate::utils::misc::get_min_max_real_dates(&dfw, "real_date")?;
    println!("Successfully got min and max real dates.");

    // Fall back to the DataFrame's own date range if the start/end strings do not parse.
    let (start_date, end_date) = (
        NaiveDate::parse_from_str(&start, "%Y-%m-%d").unwrap_or(dfw_start_date),
        NaiveDate::parse_from_str(&end, "%Y-%m-%d").unwrap_or(dfw_end_date),
    );
    println!("Successfully parsed start and end dates.");

    // Restrict the wide frame to [start_date, end_date].
    dfw = dfw
        .lazy()
        .filter(col("real_date").lt_eq(lit(end_date)))
        .filter(col("real_date").gt_eq(lit(start_date)))
        .collect()?;

    println!("Successfully filtered the DataFrame.");

    // Note: with the current match, est_freq == "D" is also handled by
    // freq_period_calc; freq_daily_calc is only reached for the "X" arm.
    let mut dfw = match est_freq.as_str() {
        "X" => freq_daily_calc(
            &dfw,
            lback_periods,
            &lback_method,
            half_life,
            remove_zeros,
            nan_tolerance,
        )?,
        _ => freq_period_calc(
            &dfw,
            lback_periods,
            &lback_method,
            half_life,
            remove_zeros,
            nan_tolerance,
            &est_freq,
        )?,
    };

    // Rename each value column to include the postfix.
    for ic in 0..dfw.get_column_names().len() {
        let col_name = dfw.get_column_names()[ic].to_string();
        if col_name == "real_date" {
            continue;
        }
        let new_col_name = format!("{}{}", col_name, postfix);
        dfw.rename(&col_name, new_col_name.into())?;
    }

    // Melt the wide result back into Quantamental format.
    dfw = pivot_wide_dataframe_to_qdf(dfw, Some("value".to_string()))?;

    Ok(dfw)
}