Refactor historic volatility calculations to improve data handling and add postfix renaming for result columns

This commit is contained in:
Palash Tyagi 2025-04-05 16:31:04 +01:00
parent e4412c638b
commit 19ff090508

View File

@ -1,7 +1,10 @@
use crate::utils::qdf::pivots::*;
use crate::utils::qdf::reduce_df::*;
use chrono::{Datelike, NaiveDate};
use ndarray::{s, Array, Array1, Zip};
use polars::prelude::*;
use polars::series::Series; // Series struct
// use polars::time::Duration;
/// Returns the annualization factor for 252 trading days.
@ -74,9 +77,19 @@ fn freq_daily_calc(
);
}
let mut new_df = dfw.clone();
let idx = UInt32Chunked::from_vec(
"idx".into(),
(lback_periods - 1..dfw.height())
.map(|x| x as u32)
.collect(),
);
let real_date_col = dfw.column("real_date")?.take(&idx)?;
let mut new_df = DataFrame::new(vec![real_date_col])?;
for col_name in dfw.get_column_names() {
if col_name == "real_date" {
continue;
}
let series = dfw.column(col_name)?;
let values: Array1<f64> = series
.f64()?
@ -87,8 +100,8 @@ fn freq_daily_calc(
let result_series = match lback_method {
"ma" => {
let mut result = Vec::new();
for i in 0..(values.len() - lback_periods + 1) {
let window = values.slice(s![i..i + lback_periods]);
for i in (lback_periods - 1)..(values.len()) {
let window = values.slice(s![i + 1 - lback_periods..=i]);
let std = flat_std(&window.to_owned(), remove_zeros);
result.push(std);
}
@ -98,8 +111,9 @@ fn freq_daily_calc(
let half_life = half_life.unwrap();
let weights = expo_weights(lback_periods, half_life);
let mut result = Vec::new();
for i in 0..(values.len() - lback_periods + 1) {
let window = values.slice(s![i..i + lback_periods]);
// Iterate over window *end* indices; previously: for i in 0..(values.len() - lback_periods + 1) {
for i in (lback_periods - 1)..(values.len()) {
let window = values.slice(s![i + 1 - lback_periods..=i]);
let std = expo_std(&window.to_owned(), &weights, remove_zeros);
result.push(std);
}
@ -134,7 +148,15 @@ fn freq_period_calc(
println!("Calculating historic volatility with the following parameters:");
println!("lback_periods: {:?}, lback_method: {:?}, half_life: {:?}, remove_zeros: {:?}, nan_tolerance: {:?}, period: {:?}", lback_periods, lback_method, half_life, remove_zeros, nan_tolerance, period);
let mut new_df = dfw.clone();
let period_indices: Vec<usize> = get_period_indices(dfw, period)?;
// Keep only the `real_date` rows at `period_indices`; pandas equivalent:
// new_df = dfw['real_date'].iloc[period_indices].copy()
let idx = UInt32Chunked::from_vec(
"idx".into(),
period_indices.iter().map(|&x| x as u32).collect(),
);
let real_date_col = dfw.column("real_date")?.take(&idx)?;
let mut new_df = DataFrame::new(vec![real_date_col])?;
for col_name in dfw.get_column_names() {
if col_name == "real_date" {
@ -150,7 +172,6 @@ fn freq_period_calc(
let result_series = match lback_method {
"ma" => {
let mut result = Vec::new();
let period_indices = get_period_indices(dfw, period)?;
for &i in &period_indices {
if i >= lback_periods - 1 {
let window = values.slice(s![i + 1 - lback_periods..=i]);
@ -166,7 +187,6 @@ fn freq_period_calc(
let half_life = half_life.unwrap();
let weights = expo_weights(lback_periods, half_life);
let mut result = Vec::new();
let period_indices = get_period_indices(dfw, period)?;
for &i in &period_indices {
if i >= lback_periods - 1 {
let window = values.slice(s![i + 1 - lback_periods..=i]);
@ -180,7 +200,10 @@ fn freq_period_calc(
}
_ => return Err("Invalid lookback method.".into()),
};
println!("Successfully calculated result_series for column: {:?}", col_name);
println!(
"Successfully calculated result_series for column: {:?}",
col_name
);
new_df.with_column(result_series)?;
}
@ -273,7 +296,18 @@ pub fn historic_vol(
) -> Result<DataFrame, Box<dyn std::error::Error>> {
println!("Calculating historic volatility with the following parameters:");
println!("xcat: {:?},\ncids: {:?},\nlback_periods: {:?},lback_method: {:?},\nhalf_life: {:?},\nstart: {:?},\nend: {:?},\nest_freq: {:?},\nremove_zeros: {:?},\npostfix: {:?},\nnan_tolerance: {:?}", xcat, cids, lback_periods,lback_method, half_life, start, end, est_freq, remove_zeros, postfix, nan_tolerance);
let mut dfw = pivot_dataframe_by_ticker(df.clone(), Some("value".to_string()))?;
let rdf = reduce_dataframe(
df.clone(),
cids,
Some(vec![xcat]),
None,
start.clone(),
end.clone(),
false,
)?;
let mut dfw = pivot_dataframe_by_ticker(rdf, Some("value".to_string()))?;
println!("Successfully pivoted the DataFrame.");
@ -333,6 +367,7 @@ pub fn historic_vol(
println!("Successfully filtered the DataFrame.");
let period = match est_freq.as_str() {
"D" => "daily",
"W" => "weekly",
"M" => "monthly",
_ => return Err("Invalid frequency specified.".into()),
@ -340,7 +375,7 @@ pub fn historic_vol(
println!("Successfully got period.");
let dfw = match est_freq.as_str() {
let mut dfw = match est_freq.as_str() {
"D" => freq_daily_calc(
&dfw,
lback_periods,
@ -360,5 +395,17 @@ pub fn historic_vol(
)?,
};
// Append `postfix` to every column name except `real_date`.
for ic in 0..dfw.get_column_names().len() {
let col_name = dfw.get_column_names()[ic].to_string();
if col_name == "real_date" {
continue;
}
let new_col_name = format!("{}{}", col_name, postfix);
dfw.rename(&col_name, new_col_name.into())?;
}
dfw = pivot_wide_dataframe_to_qdf(dfw, Some("value".to_string()))?;
Ok(dfw)
}