diff --git a/src/matrix.rs b/src/matrix.rs index f6c1302..8d4b670 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -11,14 +11,15 @@ macro_rules! int { #[derive(Debug, PartialEq)] pub struct Matrix { - pub n: usize, + pub n: usize, //rows + pub m: usize, // columns values: Vec, } impl Matrix { - pub fn new(n: usize, values: Vec) -> Option { - if n.pow(2) == values.len() { - Some(Matrix { n, values }) + pub fn new(n: usize, m: usize, values: Vec) -> Option { + if n*m == values.len() { + Some(Matrix { n, m, values }) } else { None } @@ -40,19 +41,25 @@ impl Matrix { values.extend(row); } - Matrix::new(n, values) + Matrix::new(n, n, values) } } impl Index<(usize, usize)> for Matrix { type Output = Element; fn index(&self, index: (usize, usize)) -> &Self::Output { + if index.0>=self.m || index.1 >= self.n { + panic!(); + } &self.values[(self.n * index.0) + index.1] } } impl IndexMut<(usize, usize)> for Matrix { fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output { + if index.0>=self.m || index.1 >= self.n { + panic!(); + } &mut self.values[(self.n * index.0) + index.1] } } @@ -65,22 +72,27 @@ mod tests { assert_eq!( Matrix { n: 2, + m: 2, values: vec![int!(1), int!(2), int!(3), int!(4)], }, - Matrix::new(2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap() + Matrix::new(2, 2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap() ); - assert!(Matrix::new(3, vec![int!(1), int!(2)]).is_none()); + assert!(Matrix::new(2, 2, vec![int!(1), int!(2)]).is_none()); } #[test] fn indexes() { - let mut m = Matrix::new(2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap(); + let mut m = Matrix::new(2, 2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap(); assert_eq!(m[(0, 0)], int!(1)); assert_eq!(m[(1, 0)], int!(3)); m[(1, 0)] = int!(5); assert_eq!(m[(1, 0)], int!(5)); + + let m2 = Matrix::new(3, 2, vec![int!(1), int!(2), int!(3), int!(4), int!(5), int!(6)]).unwrap(); + assert_eq!(m2[(0, 2)], int!(3)); + assert_eq!(m2[(1, 0)], int!(4)); } #[test]