From 19ff090508cadbc9232b0df6f0e7c8f60cb4b778 Mon Sep 17 00:00:00 2001 From: Palash Tyagi <23239946+Magnus167@users.noreply.github.com> Date: Sat, 5 Apr 2025 16:31:04 +0100 Subject: [PATCH] Refactor historic volatility calculations to improve data handling and add postfix renaming for result columns --- src/panel/historic_vol.rs | 69 ++++++++++++++++++++++++++++++++------- 1 file changed, 58 insertions(+), 11 deletions(-) diff --git a/src/panel/historic_vol.rs b/src/panel/historic_vol.rs index 01dcbe2..49e1be5 100644 --- a/src/panel/historic_vol.rs +++ b/src/panel/historic_vol.rs @@ -1,7 +1,10 @@ use crate::utils::qdf::pivots::*; +use crate::utils::qdf::reduce_df::*; use chrono::{Datelike, NaiveDate}; use ndarray::{s, Array, Array1, Zip}; use polars::prelude::*; +use polars::series::Series; // Series struct + // use polars::time::Duration; /// Returns the annualization factor for 252 trading days. @@ -74,9 +77,19 @@ fn freq_daily_calc( ); } - let mut new_df = dfw.clone(); + let idx = UInt32Chunked::from_vec( + "idx".into(), + (lback_periods - 1..dfw.height()) + .map(|x| x as u32) + .collect(), + ); + let real_date_col = dfw.column("real_date")?.take(&idx)?; + let mut new_df = DataFrame::new(vec![real_date_col])?; for col_name in dfw.get_column_names() { + if col_name == "real_date" { + continue; + } let series = dfw.column(col_name)?; let values: Array1 = series .f64()? @@ -87,8 +100,8 @@ fn freq_daily_calc( let result_series = match lback_method { "ma" => { let mut result = Vec::new(); - for i in 0..(values.len() - lback_periods + 1) { - let window = values.slice(s![i..i + lback_periods]); + for i in (lback_periods - 1)..(values.len()) { + let window = values.slice(s![i + 1 - lback_periods..=i]); let std = flat_std(&window.to_owned(), remove_zeros); result.push(std); } @@ -98,8 +111,9 @@ fn freq_daily_calc( let half_life = half_life.unwrap(); let weights = expo_weights(lback_periods, half_life); let mut result = Vec::new(); - for i in 0..(values.len() - lback_periods + 1) { - let window = values.slice(s![i..i + lback_periods]); + // for i in 0..(values.len() - lback_periods + 1) { + for i in (lback_periods - 1)..(values.len()) { + let window = values.slice(s![i + 1 - lback_periods..=i]); let std = expo_std(&window.to_owned(), &weights, remove_zeros); result.push(std); } @@ -134,7 +148,15 @@ fn freq_period_calc( println!("Calculating historic volatility with the following parameters:"); println!("lback_periods: {:?}, lback_method: {:?}, half_life: {:?}, remove_zeros: {:?}, nan_tolerance: {:?}, period: {:?}", lback_periods, lback_method, half_life, remove_zeros, nan_tolerance, period); - let mut new_df = dfw.clone(); + let period_indices: Vec = get_period_indices(dfw, period)?; + + // new_df = dfw['real_date'].iloc[period_indices].copy() + let idx = UInt32Chunked::from_vec( + "idx".into(), + period_indices.iter().map(|&x| x as u32).collect(), + ); + let real_date_col = dfw.column("real_date")?.take(&idx)?; + let mut new_df = DataFrame::new(vec![real_date_col])?; for col_name in dfw.get_column_names() { if col_name == "real_date" { @@ -150,7 +172,6 @@ fn freq_period_calc( let result_series = match lback_method { "ma" => { let mut result = Vec::new(); - let period_indices = get_period_indices(dfw, period)?; for &i in &period_indices { if i >= lback_periods - 1 { let window = values.slice(s![i + 1 - lback_periods..=i]); @@ -166,7 +187,6 @@ fn freq_period_calc( let half_life = half_life.unwrap(); let weights = expo_weights(lback_periods, half_life); let mut result = Vec::new(); - let period_indices = get_period_indices(dfw, period)?; for &i in &period_indices { if i >= lback_periods - 1 { let window = values.slice(s![i + 1 - lback_periods..=i]); @@ -180,7 +200,10 @@ fn freq_period_calc( } _ => return Err("Invalid lookback method.".into()), }; - println!("Successfully calculated result_series for column: {:?}", col_name); + println!( + "Successfully calculated result_series for column: {:?}", + col_name + ); new_df.with_column(result_series)?; } @@ -273,7 +296,18 @@ pub fn historic_vol( ) -> Result> { println!("Calculating historic volatility with the following parameters:"); println!("xcat: {:?},\ncids: {:?},\nlback_periods: {:?},lback_method: {:?},\nhalf_life: {:?},\nstart: {:?},\nend: {:?},\nest_freq: {:?},\nremove_zeros: {:?},\npostfix: {:?},\nnan_tolerance: {:?}", xcat, cids, lback_periods,lback_method, half_life, start, end, est_freq, remove_zeros, postfix, nan_tolerance); - let mut dfw = pivot_dataframe_by_ticker(df.clone(), Some("value".to_string()))?; + + let rdf = reduce_dataframe( + df.clone(), + cids, + Some(vec![xcat]), + None, + start.clone(), + end.clone(), + false, + )?; + + let mut dfw = pivot_dataframe_by_ticker(rdf, Some("value".to_string()))?; println!("Successfully pivoted the DataFrame."); @@ -333,6 +367,7 @@ pub fn historic_vol( println!("Successfully filtered the DataFrame."); let period = match est_freq.as_str() { + "D" => "daily", "W" => "weekly", "M" => "monthly", _ => return Err("Invalid frequency specified.".into()), @@ -340,7 +375,7 @@ pub fn historic_vol( println!("Successfully got period."); - let dfw = match est_freq.as_str() { + let mut dfw = match est_freq.as_str() { "D" => freq_daily_calc( &dfw, lback_periods, @@ -360,5 +395,17 @@ pub fn historic_vol( )?, }; + // rename each column to include the postfix + for ic in 0..dfw.get_column_names().len() { + let col_name = dfw.get_column_names()[ic].to_string(); + if col_name == "real_date" { + continue; + } + let new_col_name = format!("{}{}", col_name, postfix); + dfw.rename(&col_name, new_col_name.into())?; + } + + dfw = pivot_wide_dataframe_to_qdf(dfw, Some("value".to_string()))?; + Ok(dfw) }