Enhance Matrix implementation with generic filled method and add NaN support

This commit is contained in:
Palash Tyagi 2025-07-07 21:20:57 +01:00
parent e48ce7d6d7
commit 2a63e6d5ab

View File

@ -361,17 +361,18 @@ impl<T: Clone> Matrix<T> {
} }
Matrix::from_vec(data, n, self.cols) Matrix::from_vec(data, n, self.cols)
} }
}
impl Matrix<f64> {
/// Creates a new matrix filled with a specific value of the specified size. /// Creates a new matrix filled with a specific value of the specified size.
pub fn filled(rows: usize, cols: usize, value: f64) -> Self { pub fn filled(rows: usize, cols: usize, value: T) -> Self {
Matrix { Matrix {
rows, rows,
cols, cols,
data: vec![value; rows * cols], // Fill with the specified value data: vec![value; rows * cols], // Fill with the specified value
} }
} }
}
impl Matrix<f64> {
/// Creates a new matrix filled with zeros of the specified size. /// Creates a new matrix filled with zeros of the specified size.
pub fn zeros(rows: usize, cols: usize) -> Self { pub fn zeros(rows: usize, cols: usize) -> Self {
Matrix::filled(rows, cols, 0.0) Matrix::filled(rows, cols, 0.0)
@ -381,6 +382,11 @@ impl Matrix<f64> {
pub fn ones(rows: usize, cols: usize) -> Self { pub fn ones(rows: usize, cols: usize) -> Self {
Matrix::filled(rows, cols, 1.0) Matrix::filled(rows, cols, 1.0)
} }
/// Creates a new matrix filled with NaN values of the specified size.
pub fn nan(rows: usize, cols: usize) -> Matrix<f64> {
Matrix::filled(rows, cols, f64::NAN)
}
} }
impl<T> Index<(usize, usize)> for Matrix<T> { impl<T> Index<(usize, usize)> for Matrix<T> {
@ -1929,5 +1935,19 @@ mod tests {
assert_eq!(m.rows(), 2); assert_eq!(m.rows(), 2);
assert_eq!(m.cols(), 2); assert_eq!(m.cols(), 2);
assert_eq!(m.data(), &[42.5, 42.5, 42.5, 42.5]); assert_eq!(m.data(), &[42.5, 42.5, 42.5, 42.5]);
// test with an integer matrix
let m = Matrix::<i32>::filled(2, 3, 7);
assert_eq!(m.rows(), 2);
assert_eq!(m.cols(), 3);
assert_eq!(m.data(), &[7, 7, 7, 7, 7, 7]);
// test with nans
let m = Matrix::nan(3, 3);
assert_eq!(m.rows(), 3);
assert_eq!(m.cols(), 3);
for &value in m.data() {
assert!(value.is_nan(), "Expected NaN, got {}", value);
}
} }
} }