diff --git a/src/matrix/mat.rs b/src/matrix/mat.rs index 64bf6de..bf4265b 100644 --- a/src/matrix/mat.rs +++ b/src/matrix/mat.rs @@ -310,6 +310,26 @@ impl Matrix { } } +impl Matrix { + /// Creates a new matrix filled with a specific value of the specified size. + pub fn filled(rows: usize, cols: usize, value: f64) -> Self { + Matrix { + rows, + cols, + data: vec![value; rows * cols], // Fill with the specified value + } + } + /// Creates a new matrix filled with zeros of the specified size. + pub fn zeros(rows: usize, cols: usize) -> Self { + Matrix::filled(rows, cols, 0.0) + } + + /// Creates a new matrix filled with ones of the specified size. + pub fn ones(rows: usize, cols: usize) -> Self { + Matrix::filled(rows, cols, 1.0) + } +} + impl Index<(usize, usize)> for Matrix { type Output = T; @@ -1794,4 +1814,25 @@ mod tests { } } } + + #[test] + fn test_matrix_zeros_ones_filled() { + // Test zeros + let m = Matrix::::zeros(2, 3); + assert_eq!(m.rows(), 2); + assert_eq!(m.cols(), 3); + assert_eq!(m.data(), &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]); + + // Test ones + let m = Matrix::::ones(3, 2); + assert_eq!(m.rows(), 3); + assert_eq!(m.cols(), 2); + assert_eq!(m.data(), &[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]); + + // Test filled + let m = Matrix::::filled(2, 2, 42.5); + assert_eq!(m.rows(), 2); + assert_eq!(m.cols(), 2); + assert_eq!(m.data(), &[42.5, 42.5, 42.5, 42.5]); + } }