diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index 086b715..ceb693e 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -361,17 +361,18 @@ impl Matrix { } Matrix::from_vec(data, n, self.cols) } -} -impl Matrix { /// Creates a new matrix filled with a specific value of the specified size. - pub fn filled(rows: usize, cols: usize, value: f64) -> Self { + pub fn filled(rows: usize, cols: usize, value: T) -> Self { Matrix { rows, cols, data: vec![value; rows * cols], // Fill with the specified value } } +} + +impl Matrix { /// Creates a new matrix filled with zeros of the specified size. pub fn zeros(rows: usize, cols: usize) -> Self { Matrix::filled(rows, cols, 0.0) @@ -381,6 +382,11 @@ impl Matrix { pub fn ones(rows: usize, cols: usize) -> Self { Matrix::filled(rows, cols, 1.0) } + + /// Creates a new matrix filled with NaN values of the specified size. + pub fn nan(rows: usize, cols: usize) -> Matrix { + Matrix::filled(rows, cols, f64::NAN) + } } impl Index<(usize, usize)> for Matrix { @@ -1929,5 +1935,19 @@ mod tests { assert_eq!(m.rows(), 2); assert_eq!(m.cols(), 2); assert_eq!(m.data(), &[42.5, 42.5, 42.5, 42.5]); + + // test with an integer matrix + let m = Matrix::::filled(2, 3, 7); + assert_eq!(m.rows(), 2); + assert_eq!(m.cols(), 3); + assert_eq!(m.data(), &[7, 7, 7, 7, 7, 7]); + + // test with nans + let m = Matrix::nan(3, 3); + assert_eq!(m.rows(), 3); + assert_eq!(m.cols(), 3); + for &value in m.data() { + assert!(value.is_nan(), "Expected NaN, got {}", value); + } } }