rusttest/src/main.rs

74 lines
2.3 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use chrono::NaiveDate;
use rustframe::{
frame::{Frame, RowIndex},
matrix::{BoolOps, Matrix, SeriesOps},
utils::{BDateFreq, BDatesList},
};
/// Builds a matrix with `n_rows` rows and `n_cols` columns whose backing
/// buffer holds the sequence `0.0, 1.0, 2.0, …, (n_cols * n_rows - 1) as f64`.
///
/// # Arguments
/// * `n_cols` - number of columns of the resulting matrix.
/// * `n_rows` - number of rows of the resulting matrix.
///
/// # Panics
/// Propagates any panic raised by `Matrix::from_vec` for shapes it rejects.
fn generate_array(n_cols: usize, n_rows: usize) -> Matrix<f64> {
    let values: Vec<f64> = (0..n_cols * n_rows).map(|x| x as f64).collect();
    // `Matrix::from_vec` takes its dimensions as `(rows, cols)`. The earlier
    // call handed them over as `(n_cols, n_rows)`, which transposes the shape
    // for any non-square request; it went unnoticed because the only caller
    // in this file asks for a square 1000x1000 matrix.
    Matrix::from_vec(values, n_rows, n_cols)
}
/// Smoke-checks element-wise `Frame` multiplication and scalar broadcasting
/// on `Matrix`.
///
/// # Panics
/// Panics when an assertion fails, or when building the business-date list
/// returns an error (both results are unwrapped).
fn test_a() {
    // Four business days starting 2024-01-02
    let period_count = 4;
    let business_days =
        BDatesList::from_n_periods("2024-01-02".to_string(), BDateFreq::Daily, period_count)
            .unwrap();
    let index_dates: Vec<NaiveDate> = business_days.list().unwrap();

    let labels: Vec<String> = vec![String::from("a"), String::from("b")];

    // Two 4x2 matrices given column by column; the second mirrors the first.
    let ascending: Matrix<f64> =
        Matrix::from_cols(vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]]);
    let descending: Matrix<f64> =
        Matrix::from_cols(vec![vec![4.0, 3.0, 2.0, 1.0], vec![8.0, 7.0, 6.0, 5.0]]);

    // Both frames share the same column labels and date index.
    let left: Frame<f64> = Frame::new(
        ascending.clone(),
        labels.clone(),
        Some(RowIndex::Date(index_dates.clone())),
    );
    let right: Frame<f64> = Frame::new(descending, labels, Some(RowIndex::Date(index_dates)));

    // Element-wise product, then reduce everything to one number:
    // (1*4 + 2*3 + 3*2 + 4*1) + (5*8 + 6*7 + 7*6 + 8*5) = 20 + 164 = 184.
    let product: Frame<f64> = &left * &right;
    let grand_total: f64 = product.sum_vertical().iter().sum();
    assert_eq!(grand_total, 184.0);

    // Scalar broadcasting, one step at a time: +1, -1, *2, /2 must round-trip
    // back to the starting matrix.
    let plus_one: Matrix<f64> = &ascending + 1.0;
    let restored: Matrix<f64> = &plus_one - 1.0;
    let doubled: Matrix<f64> = &restored * 2.0;
    let halved: Matrix<f64> = &doubled / 2.0;
    assert!(halved.eq_elem(ascending.clone()).all());

    // The same round trip written as a single expression.
    let round_trip: Matrix<f64> = &(&(&(&ascending + 1.0) - 1.0) * 2.0) / 2.0;
    assert!(round_trip.eq_elem(ascending).all());
}
/// Times an element-wise self-multiplication followed by `sum_vertical` on a
/// 1000x1000 matrix and prints the elapsed wall-clock duration.
///
/// Build with `--release` for a meaningful number; debug builds are far slower.
fn main() {
    // Benchmark shape: 1000x1000.
    let n_cols = 1000;
    let n_rows = 1000;
    let ma = generate_array(n_cols, n_rows);

    // Only the multiply-and-reduce step is timed, not the matrix construction.
    let start = std::time::Instant::now();
    // `black_box` makes the result observable to the optimizer, so the work
    // being measured cannot be discarded as dead code. The underscore-prefixed
    // binding keeps the value alive until the end of `main` (so deallocation
    // stays outside the timed region) without an unused-variable warning.
    let _column_sums = std::hint::black_box((&ma * &ma).sum_vertical());
    let duration = start.elapsed();

    println!("Duration: {:?}", duration);
}