msyrs/src/utils/misc.rs

97 lines
3.0 KiB
Rust

use polars::prelude::*;
use std::collections::HashMap;
use std::error::Error;
/// Break a ticker into its `(cid, xcat)` pair.
///
/// The ticker is cut at its first underscore only, so any further
/// underscores stay inside the `xcat` part.
///
/// # Errors
///
/// Returns an error when the ticker contains no underscore.
pub fn split_ticker(ticker: String) -> Result<(String, String), Box<dyn Error>> {
    match ticker.split_once('_') {
        Some((cid, xcat)) => Ok((cid.to_owned(), xcat.to_owned())),
        None => Err("Invalid ticker format".into()),
    }
}
/// Extract the `cid` part of a ticker string.
///
/// # Errors
///
/// Forwards the error from `split_ticker` when the ticker cannot be split.
#[allow(dead_code)]
pub fn get_cid(ticker: String) -> Result<String, Box<dyn Error>> {
    let (cid, _xcat) = split_ticker(ticker)?;
    Ok(cid)
}
/// Extract the `xcat` part of a ticker string.
///
/// # Errors
///
/// Forwards the error from `split_ticker` when the ticker cannot be split.
#[allow(dead_code)]
pub fn get_xcat(ticker: String) -> Result<String, Box<dyn Error>> {
    let (_cid, xcat) = split_ticker(ticker)?;
    Ok(xcat)
}
/// Build a ticker string by joining `cid` and `xcat` with an underscore.
pub fn create_ticker(cid: &str, xcat: &str) -> String {
    [cid, xcat].join("_")
}
/// Build every `cid`/`xcat` ticker combination.
///
/// Tickers are produced cid by cid: all `xcats` for the first `cid`, then
/// all `xcats` for the second, and so on.
pub fn create_intersecting_tickers(cids: &[&str], xcats: &[&str]) -> Vec<String> {
    cids.iter()
        .flat_map(|cid| xcats.iter().map(move |xcat| create_ticker(cid, xcat)))
        .collect()
}
/// Backend helper that collects the unique strings of a string column.
///
/// The column is de-duplicated, sorted with the default sort options and
/// stripped of nulls before its values are copied into owned `String`s.
///
/// # Errors
///
/// Returns an error if the unique, sort or string-view step fails.
pub fn _get_unique_strs_from_str_column_object(
    col: &Column,
) -> Result<Vec<String>, Box<dyn Error>> {
    let distinct = col.unique()?;
    let ordered = distinct.sort(SortOptions::default())?;
    let non_null = ordered.drop_nulls();
    let strings = non_null.str()?;

    let values: Vec<String> = strings
        .iter()
        .map(|value| value.unwrap_or_default().to_string())
        .collect();
    Ok(values)
}
/// Get the unique values from a string column in a DataFrame.
///
/// Looks up the column named `col` in `df` and passes it to
/// `_get_unique_strs_from_str_column_object`, which does the actual work.
///
/// # Errors
///
/// Returns an error if `df` has no column named `col`, or if the helper
/// fails on that column (presumably when it is not a string column).
pub fn get_unique_from_str_column(
    df: &DataFrame,
    col: &str,
) -> Result<Vec<String>, Box<dyn Error>> {
    // Propagate a failed column lookup to the caller instead of panicking:
    // this function already returns a `Result`.
    let column = df.column(col)?;
    _get_unique_strs_from_str_column_object(column)
}
/// Return the requested `cids` that have at least one requested `xcat`
/// among `found_tickers`.
///
/// Each entry of `found_tickers` is split at its first underscore into a
/// `cid` and an `xcat` (the same rule `split_ticker` applies). A `cid` is
/// kept when it appears in `cids` and at least one `xcat` found for it
/// appears in `xcats`.
///
/// The result follows the order of `cids` and contains each `cid` at most
/// once, so the same input always yields the same vector.
///
/// # Panics
///
/// Panics if any entry of `found_tickers` contains no underscore.
pub fn get_intersecting_cids_str_func(
    cids: &Vec<String>,
    xcats: &Vec<String>,
    found_tickers: &Vec<String>,
) -> Vec<String> {
    // Group the xcats that were actually found under their cid. The map
    // borrows slices of `found_tickers`, so no string is cloned here.
    let mut xcats_by_cid: HashMap<&str, Vec<&str>> = HashMap::new();
    for ticker in found_tickers {
        let (cid, xcat) = ticker
            .split_once('_')
            .unwrap_or_else(|| panic!("Invalid ticker format: {}", ticker));
        xcats_by_cid.entry(cid).or_default().push(xcat);
    }

    // Walk `cids` (not the map) so the output order is deterministic.
    let mut kept_cids: Vec<String> = Vec::new();
    for cid in cids {
        // `remove` hands out each cid's xcats only once, which also collapses
        // duplicate entries in `cids`.
        if let Some(found_xcats) = xcats_by_cid.remove(cid.as_str()) {
            let has_requested_xcat = found_xcats
                .iter()
                .any(|found| xcats.iter().any(|wanted| wanted.as_str() == *found));
            if has_requested_xcat {
                kept_cids.push(cid.clone());
            }
        }
    }
    kept_cids
}