diff --git a/src/bkz.rs b/src/bkz.rs index 3910a04..76e4158 100644 --- a/src/bkz.rs +++ b/src/bkz.rs @@ -92,7 +92,6 @@ pub fn bkz_reduce( bigl2::lattice_reduce(basis, delta, eta); } - #[cfg(test)] mod tests { use super::*; diff --git a/src/matrix.rs b/src/matrix.rs index 91b207a..f43bd36 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,7 +1,7 @@ use rug::ops::Pow; +use rug::Integer; use std::ops::{Index, IndexMut}; - -type Element = rug::Integer; +use crate::vector::IntVector; macro_rules! int { ($x:expr) => { @@ -13,54 +13,71 @@ macro_rules! int { pub struct Matrix { pub n: usize, // rows pub m: usize, // columns - values: Vec, + columns: Vec, } impl Matrix { - pub fn new(n: usize, m: usize, values: Vec) -> Option { - if n * m == values.len() { - Some(Matrix { n, m, values }) - } else { - None + pub fn new(n: usize, m: usize, values: Vec) -> Option { + if n * m != values.len() { + return None; } + let mut columns = vec![Vec::with_capacity(m); n]; + for (i, value) in values.into_iter().enumerate() { + let col = i % n; + columns[col].push(value); + } + Some(Matrix { + n, + m, + columns: columns.into_iter().map(IntVector::from_vec).collect(), + }) } - pub fn new_lattice(noise_bits: usize, ciphertexts: Vec) -> Option { + pub fn new_lattice(noise_bits: usize, ciphertexts: Vec) -> Option { let n = ciphertexts.len(); - let mut values = Vec::with_capacity(n * n); + let mut columns = vec![Vec::with_capacity(n); n]; - // First row: [2^(noise_bits+1), ciphertexts[1], ..., ciphertexts[t]] - values.push(int!(2u64).pow(noise_bits as u32 + 1)); - values.extend_from_slice(&ciphertexts[1..]); - - // x0 on diagonal, 0 everywhere else - let x0 = &ciphertexts[0]; + columns[0].push(int!(2u64).pow((noise_bits + 1) as u32)); for i in 1..n { - let mut row = vec![int!(0); n]; - row[i] = x0.clone(); - values.extend(row); + columns[i].push(ciphertexts[i].clone()); } - Matrix::new(n, n, values) + for i in 1..n { + for j in 0..n { + if i == j { + columns[j].push(ciphertexts[0].clone()); + } else { + columns[j].push(int!(0)); + } + } + } + + Some(Matrix { + n, + m: n, + columns: columns.into_iter().map(IntVector::from_vec).collect(), + }) } } impl Index<(usize, usize)> for Matrix { - type Output = Element; + type Output = Integer; fn index(&self, index: (usize, usize)) -> &Self::Output { - if index.0 >= self.m || index.1 >= self.n { - panic!(); + let (col, row) = index; + if row >= self.n || col >= self.m { + panic!("Matrix index out of bounds"); } - &self.values[(self.n * index.0) + index.1] + &self.columns[row][col] } } impl IndexMut<(usize, usize)> for Matrix { fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output { - if index.0 >= self.m || index.1 >= self.n { - panic!(); + let (col, row) = index; + if row >= self.n || col >= self.m { + panic!("Matrix index out of bounds"); } - &mut self.values[(self.n * index.0) + index.1] + &mut self.columns[row][col] } } @@ -68,24 +85,24 @@ impl IndexMut<(usize, usize)> for Matrix { mod tests { use super::*; use std::panic; + #[test] - fn simple_matrix() { - assert_eq!( - Matrix { - n: 2, - m: 2, - values: vec![int!(1), int!(2), int!(3), int!(4)], - }, - Matrix::new(2, 2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap() - ); - assert!(Matrix::new(2, 2, vec![int!(1), int!(2)]).is_none()); + fn simple_dimensions_and_index() { + let m = Matrix::new(2, 2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap(); + assert_eq!(m.n, 2); + assert_eq!(m.m, 2); + + // values: [1,2,3,4] + // columns: [[1,3], [2,4]] + assert_eq!(m[(0, 0)], int!(1)); + assert_eq!(m[(0, 1)], int!(2)); + assert_eq!(m[(1, 0)], int!(3)); + assert_eq!(m[(1, 1)], int!(4)); } #[test] - fn indexes() { + fn indexes_and_mutation() { let mut m = Matrix::new(2, 2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap(); - - assert_eq!(m[(0, 0)], int!(1)); assert_eq!(m[(1, 0)], int!(3)); m[(1, 0)] = int!(5); @@ -97,38 +114,43 @@ mod tests { vec![int!(1), int!(2), int!(3), int!(4), int!(5), int!(6)], ) .unwrap(); + assert_eq!(m2[(0, 2)], int!(3)); assert_eq!(m2[(1, 0)], int!(4)); - let result = panic::catch_unwind(|| { - let _ = m2[(0, 3)]; - }); - assert!(result.is_err(), "Expected panic on m2[(0, 3)]"); - let result2 = panic::catch_unwind(|| { - let _ = m2[(2, 0)]; - }); - assert!(result2.is_err(), "Expected panic on m2[(2, 0)]"); + let result = panic::catch_unwind(|| { let _ = m2[(2, 0)]; }); + assert!(result.is_err(), "Expected panic on m2[(2, 0)]"); + + let result2 = panic::catch_unwind(|| { let _ = m2[(0, 3)]; }); + assert!(result2.is_err(), "Expected panic on m2[(0, 3)]"); } #[test] - fn test_new_lattice() { + fn test_new_lattice_layout() { let ciphertexts = vec![int!(5), int!(8), int!(12)]; let noise_bits = 2; + let lattice = Matrix::new_lattice(noise_bits, ciphertexts.clone()).unwrap(); - let expected_values = vec![ - int!(8), - int!(8), - int!(12), - int!(0), - int!(5), - int!(0), - int!(0), - int!(0), - int!(5), + assert_eq!(lattice.n, 3); + assert_eq!(lattice.m, 3); + + // 1st column = [2^(noise+1), 0, 0] = [8,0,0] + // 2nd column = [ciphertexts[1], ciphertexts[0], 0] = [8,5,0] + // 3rd column = [ciphertexts[2], 0, ciphertexts[0]] = [12,0,5] + let expected_flat = vec![ + int!(8), int!(8), int!(12), + int!(0), int!(5), int!(0), + int!(0), int!(0), int!(5), ]; - let lattice = Matrix::new_lattice(noise_bits, ciphertexts).unwrap(); - assert_eq!(lattice.n, 3); - assert_eq!(lattice.values, expected_values); + let mut actual_flat = Vec::with_capacity(9); + for col in 0..lattice.m { + for row in 0..lattice.n { + actual_flat.push(lattice[(col, row)].clone()); + } + } + + assert_eq!(actual_flat, expected_flat); } } + diff --git a/src/vector.rs b/src/vector.rs index aed0625..503a8b2 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -10,7 +10,7 @@ macro_rules! int { }; } -#[derive(Clone)] +#[derive(Clone, PartialEq)] pub struct IntVector { elements: Vec, }