// msyrs/src/panel/historic_vol.rs

use crate::utils::misc::*;
use crate::utils::qdf::pivots::*;
use crate::utils::qdf::reduce_df::*;
use chrono::NaiveDate;
use ndarray::{s, Array, Array1, Zip};
use polars::prelude::*;
use polars::series::Series;
/// Returns the annualization factor for 252 trading days.
/// (SQRT(252))
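/// For example, `annualization_factor()` ≈ 15.8745.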
#[allow(dead_code)]
fn annualization_factor() -> f64 {
252f64.sqrt()
}
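/// Exponential lookback weights for a window of `lback_periods` observations,
/// normalized to sum to 1, with the most recent observation weighted most heavily.
///
/// For example, `expo_weights(3, 1.0)` uses a decay factor of `2^(-1) = 0.5`, giving
/// unnormalized weights `[0.125, 0.25, 0.5]` and normalized weights of roughly
/// `[1/7, 2/7, 4/7]`.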
#[allow(dead_code)]
fn expo_weights(lback_periods: usize, half_life: f64) -> Array1<f64> {
// Calculates exponential series weights for finite horizon, normalized to 1.
let decf = 2f64.powf(-1.0 / half_life);
let mut weights = Array::from_iter(
(0..lback_periods)
.map(|ii| (lback_periods - ii - 1) as f64)
.map(|exponent| (decf.powf(exponent)) * (1.0 - decf)),
);
weights /= weights.sum();
weights
}
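/// Exponentially weighted volatility proxy for a window: the weighted mean of the
/// absolute values of `x` under weights `w` (a standard deviation around a zero mean).
/// With `remove_zeros`, zero observations are dropped and the remaining weights are
/// re-normalized.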
#[allow(dead_code)]
fn expo_std(x: &Array1<f64>, w: &Array1<f64>, remove_zeros: bool) -> f64 {
assert_eq!(x.len(), w.len(), "weights and window must have same length");
let (filtered_x, filtered_w) = if remove_zeros {
let indices: Vec<usize> = x
.iter()
.enumerate()
.filter_map(|(i, &val)| if val != 0.0 { Some(i) } else { None })
.collect();
(
Array::from_iter(indices.iter().map(|&i| x[i])),
Array::from_iter(indices.iter().map(|&i| w[i])),
)
} else {
(x.clone(), w.clone())
};
let filtered_w = &filtered_w / filtered_w.sum();
Zip::from(&filtered_x)
.and(&filtered_w)
.fold(0.0, |acc, &x_val, &w_val| acc + w_val * x_val.abs())
}
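/// Equally weighted volatility proxy for a window: the mean of the absolute values of
/// `x` (a standard deviation around a zero mean). With `remove_zeros`, zero
/// observations are dropped before averaging.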
#[allow(dead_code)]
fn flat_std(x: &Array1<f64>, remove_zeros: bool) -> f64 {
let filtered_x = if remove_zeros {
x.iter()
.filter(|&&val| val != 0.0)
.cloned()
.collect::<Array1<f64>>()
} else {
x.clone()
};
filtered_x.mapv(f64::abs).mean().unwrap_or(0.0)
}
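/// Rolling lookback volatility computed for every date from index `lback_periods - 1`
/// onwards, using `flat_std` for "ma" or `expo_std` for "xma". Missing values are
/// treated as 0.0; `nan_tolerance` is accepted but not yet applied.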
#[allow(unused_variables)]
fn freq_daily_calc(
dfw: &DataFrame,
lback_periods: usize,
lback_method: &str,
half_life: Option<f64>,
remove_zeros: bool,
nan_tolerance: f64,
) -> Result<DataFrame, Box<dyn std::error::Error>> {
if lback_method == "xma" {
assert!(
half_life.is_some(),
"If lback_method is 'xma', half_life must be provided."
);
}
let idx = UInt32Chunked::from_vec(
"idx".into(),
(lback_periods - 1..dfw.height())
.map(|x| x as u32)
.collect(),
);
let real_date_col = dfw.column("real_date")?.take(&idx)?;
let mut new_df = DataFrame::new(vec![real_date_col])?;
for col_name in dfw.get_column_names() {
if col_name == "real_date" {
continue;
}
let series = dfw.column(col_name)?;
let values: Array1<f64> = series
.f64()?
.into_iter()
.map(|opt| opt.unwrap_or(0.0))
.collect();
let result_series = match lback_method {
"ma" => {
let mut result = Vec::new();
for i in (lback_periods - 1)..(values.len()) {
                    let window = values.slice(s![i + 1 - lback_periods..=i]);
                    // Annualize the lookback estimate with the √252 factor.
                    let std = flat_std(&window.to_owned(), remove_zeros) * annualization_factor();
                    result.push(std);
}
Series::new(col_name.clone(), result)
}
"xma" => {
let half_life = half_life.unwrap();
let weights = expo_weights(lback_periods, half_life);
let mut result = Vec::new();
                for i in (lback_periods - 1)..values.len() {
                    let window = values.slice(s![i + 1 - lback_periods..=i]);
                    let std =
                        expo_std(&window.to_owned(), &weights, remove_zeros) * annualization_factor();
                    result.push(std);
}
Series::new(col_name.clone(), result)
}
_ => return Err("Invalid lookback method.".into()),
};
new_df.with_column(result_series)?;
}
Ok(new_df)
}
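/// Lookback volatility re-estimated at the period-end dates of `est_freq`: each
/// period-end index with a full lookback window is reduced with `flat_std` ("ma") or
/// `expo_std` ("xma") and annualized; earlier period ends are filled with NaN.
/// Missing values are treated as 0.0; `nan_tolerance` is accepted but not yet applied.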
#[allow(unused_variables)]
fn freq_period_calc(
dfw: &DataFrame,
lback_periods: usize,
lback_method: &str,
half_life: Option<f64>,
remove_zeros: bool,
nan_tolerance: f64,
est_freq: &str,
) -> Result<DataFrame, Box<dyn std::error::Error>> {
if lback_method == "xma" {
assert!(
half_life.is_some(),
"If lback_method is 'xma', half_life must be provided."
);
}
println!("Calculating historic volatility with the following parameters:");
println!("lback_periods: {:?}, lback_method: {:?}, half_life: {:?}, remove_zeros: {:?}, nan_tolerance: {:?}, period: {:?}", lback_periods, lback_method, half_life, remove_zeros, nan_tolerance, est_freq);
let period_indices: Vec<usize> = get_period_indices(dfw, est_freq)?;
    // Keep only the real_date rows at the period-end indices for the output frame.
let idx = UInt32Chunked::from_vec(
"idx".into(),
period_indices.iter().map(|&x| x as u32).collect(),
);
let real_date_col = dfw.column("real_date")?.take(&idx)?;
let mut new_df = DataFrame::new(vec![real_date_col])?;
for col_name in dfw.get_column_names() {
if col_name == "real_date" {
continue;
}
let series = dfw.column(col_name)?;
let values: Array1<f64> = series
.f64()?
.into_iter()
.map(|opt| opt.unwrap_or(0.0))
.collect();
let result_series = match lback_method {
"ma" => {
let mut result = Vec::new();
for &i in &period_indices {
if i >= lback_periods - 1 {
let window = values.slice(s![i + 1 - lback_periods..=i]);
let std = flat_std(&window.to_owned(), remove_zeros);
let std = std * annualization_factor();
result.push(std);
} else {
result.push(f64::NAN);
}
}
Series::new(col_name.clone(), result)
}
"xma" => {
let half_life = half_life.unwrap();
let weights = expo_weights(lback_periods, half_life);
let mut result = Vec::new();
for &i in &period_indices {
if i >= lback_periods - 1 {
let window = values.slice(s![i + 1 - lback_periods..=i]);
let std = expo_std(&window.to_owned(), &weights, remove_zeros);
let std = std * annualization_factor();
result.push(std);
} else {
result.push(f64::NAN);
}
}
Series::new(col_name.clone(), result)
}
_ => return Err("Invalid lookback method.".into()),
};
new_df.with_column(result_series)?;
}
Ok(new_df)
}
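/// Convenience wrapper around `get_bdates_from_col` for the `real_date` column of a
/// wide (pivoted) DataFrame.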
pub fn get_bdates_from_col_hv(
dfw: &DataFrame,
est_freq: &str,
) -> Result<Series, Box<dyn std::error::Error>> {
let date_series = dfw.column("real_date")?.as_series().unwrap();
Ok(get_bdates_from_col(date_series, est_freq)?)
}
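/// Public wrapper around `get_period_indices`.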
pub fn get_period_indices_hv(
dfw: &DataFrame,
est_freq: &str,
) -> Result<Vec<usize>, Box<dyn std::error::Error>> {
get_period_indices(dfw, est_freq)
}
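/// Row indices in `dfw` whose `real_date` matches one of the business dates returned
/// by `get_bdates_from_col` for `est_freq`.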
fn get_period_indices(
dfw: &DataFrame,
est_freq: &str,
) -> Result<Vec<usize>, Box<dyn std::error::Error>> {
let date_series = dfw.column("real_date")?.as_series().unwrap();
let mut indices = Vec::new();
let bdates: Series = get_bdates_from_col(date_series, est_freq)?;
for bdate in bdates.iter() {
if let Some(index) = date_series.iter().position(|date| date == bdate) {
indices.push(index);
}
}
Ok(indices)
}
/// Calculate annualized historic volatility for a category over a rolling lookback window.
///
/// # Arguments
/// - `df`: A Quantamental DataFrame.
/// - `xcat`: The category to calculate the historic volatility for.
/// - `cids`: A list of cross-sections. If none are provided, all cross-sections available
///   in the DataFrame will be used.
/// - `lback_periods`: The number of lookback periods used in the calculation. Defaults to 20.
/// - `lback_method`: The method used for the lookback calculation. Options are
///   'ma' (moving average) and 'xma' (exponential moving average). Defaults to 'ma'.
/// - `half_life`: The half-life of the exponential weighting function. Required when
///   `lback_method` is 'xma'.
/// - `start`: Only include data on or after this date ('%Y-%m-%d'). Defaults to the
///   earliest date available.
/// - `end`: Only include data on or before this date ('%Y-%m-%d'). Defaults to the
///   latest date available.
/// - `est_freq`: The frequency at which volatility is (re-)estimated. Options are 'D'
///   (daily), 'W' (weekly), 'M' (monthly), and 'Q' (quarterly). Defaults to 'D'.
/// - `remove_zeros`: Whether to drop zero values from the calculation. Defaults to `false`.
/// - `postfix`: A string appended to the xcat of the result series. Defaults to '_HISTVOL'.
/// - `nan_tolerance`: The maximum proportion of NaN values allowed in the calculation.
///   Defaults to 0.25 (currently accepted but not yet applied).
///
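/// # Example
///
/// A minimal sketch (not run as a doctest): it assumes `df` is a long-format
/// Quantamental DataFrame with `cid`, `xcat`, `real_date` and `value` columns, and the
/// category name "XR" is purely illustrative.
///
/// ```ignore
/// let vol_df = historic_vol(
///     df,
///     "XR".to_string(),           // hypothetical category of daily returns
///     None,                       // use all available cross-sections
///     Some(21),                   // 21-period lookback window
///     Some("xma".to_string()),    // exponentially weighted lookback
///     Some(11.0),                 // half-life, required for "xma"
///     None,                       // start: earliest available date
///     None,                       // end: latest available date
///     Some("M".to_string()),      // re-estimate volatility at month ends
///     None,                       // remove_zeros: default false
///     None,                       // postfix: default "_HISTVOL"
///     None,                       // nan_tolerance: default 0.25
/// )?;
/// ```
///
/// The result is a Quantamental DataFrame whose xcats carry the postfix
/// (`_HISTVOL` by default).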
#[allow(unused_variables)]
pub fn historic_vol(
df: polars::prelude::DataFrame,
xcat: String,
cids: Option<Vec<String>>,
lback_periods: Option<usize>,
lback_method: Option<String>,
half_life: Option<f64>,
start: Option<String>,
end: Option<String>,
est_freq: Option<String>,
remove_zeros: Option<bool>,
postfix: Option<String>,
nan_tolerance: Option<f64>,
) -> Result<DataFrame, Box<dyn std::error::Error>> {
println!("Calculating historic volatility with the following parameters:");
println!("xcat: {:?},\ncids: {:?},\nlback_periods: {:?},lback_method: {:?},\nhalf_life: {:?},\nstart: {:?},\nend: {:?},\nest_freq: {:?},\nremove_zeros: {:?},\npostfix: {:?},\nnan_tolerance: {:?}", xcat, cids, lback_periods,lback_method, half_life, start, end, est_freq, remove_zeros, postfix, nan_tolerance);
let rdf = reduce_dataframe(
df.clone(),
cids,
Some(vec![xcat]),
None,
start.clone(),
end.clone(),
false,
)?;
let mut dfw = pivot_dataframe_by_ticker(rdf, Some("value".to_string()))?;
println!("Successfully pivoted the DataFrame.");
let lback_periods = lback_periods.unwrap_or(20);
let lback_method = lback_method.unwrap_or("ma".to_string());
let half_life = half_life;
println!("Successfully got lback_periods, lback_method, and half_life.");
let start = start.unwrap_or(dfw.column("real_date")?.date()?.min().unwrap().to_string());
let end = end.unwrap_or(dfw.column("real_date")?.date()?.max().unwrap().to_string());
println!("Successfully got start and end dates.");
let est_freq = est_freq.unwrap_or("D".to_string());
println!("Successfully got est_freq.");
let remove_zeros = remove_zeros.unwrap_or(false);
println!("Successfully got remove_zeros.");
let postfix = postfix.unwrap_or("_HISTVOL".to_string());
println!("Successfully got postfix.");
let nan_tolerance = nan_tolerance.unwrap_or(0.25);
println!("Successfully got nan_tolerance.");
let (dfw_start_date, dfw_end_date) =
crate::utils::misc::get_min_max_real_dates(&dfw, "real_date")?;
println!("Successfully got min and max real dates.");
let (start_date, end_date) = (
NaiveDate::parse_from_str(&start, "%Y-%m-%d").unwrap_or_else(|_| dfw_start_date),
NaiveDate::parse_from_str(&end, "%Y-%m-%d").unwrap_or_else(|_| dfw_end_date),
);
println!("Successfully parsed start and end dates.");
    // Restrict the wide DataFrame to the requested date range (inclusive).
    dfw = dfw
        .lazy()
        .filter(
            col("real_date")
                .gt_eq(lit(start_date))
                .and(col("real_date").lt_eq(lit(end_date))),
        )
        .collect()?;
    let mut dfw = match est_freq.as_str() {
        // Daily estimation uses a plain rolling window; other frequencies are
        // re-estimated at the period-end dates.
        "d" | "D" => freq_daily_calc(
&dfw,
lback_periods,
&lback_method,
half_life,
remove_zeros,
nan_tolerance,
)?,
_ => freq_period_calc(
&dfw,
lback_periods,
&lback_method,
half_life,
remove_zeros,
nan_tolerance,
&est_freq,
)?,
};
    // Rename each value column to include the postfix.
    let col_names: Vec<String> = dfw
        .get_column_names()
        .iter()
        .map(|name| name.to_string())
        .collect();
    for col_name in col_names {
        if col_name == "real_date" {
            continue;
        }
        let new_col_name = format!("{}{}", col_name, postfix);
        dfw.rename(&col_name, new_col_name.into())?;
    }
dfw = pivot_wide_dataframe_to_qdf(dfw, Some("value".to_string()))?;
Ok(dfw)
}