use rug::ops::Pow; use std::ops::{Index, IndexMut}; type Element = rug::Integer; macro_rules! int { ($x:expr) => { rug::Integer::from($x) }; } #[derive(Debug, PartialEq)] pub struct Matrix { pub n: usize, // rows pub m: usize, // columns values: Vec, } impl Matrix { pub fn new(n: usize, m: usize, values: Vec) -> Option { if n * m == values.len() { Some(Matrix { n, m, values }) } else { None } } pub fn new_lattice(noise_bits: usize, ciphertexts: Vec) -> Option { let n = ciphertexts.len(); let mut values = Vec::with_capacity(n * n); // First row: [2^(noise_bits+1), ciphertexts[1], ..., ciphertexts[t]] values.push(int!(2u64).pow(noise_bits as u32 + 1)); values.extend_from_slice(&ciphertexts[1..]); // x0 on diagonal, 0 everywhere else let x0 = &ciphertexts[0]; for i in 1..n { let mut row = vec![int!(0); n]; row[i] = x0.clone(); values.extend(row); } Matrix::new(n, n, values) } } impl Index<(usize, usize)> for Matrix { type Output = Element; fn index(&self, index: (usize, usize)) -> &Self::Output { if index.0 >= self.m || index.1 >= self.n { panic!(); } &self.values[(self.n * index.0) + index.1] } } impl IndexMut<(usize, usize)> for Matrix { fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output { if index.0 >= self.m || index.1 >= self.n { panic!(); } &mut self.values[(self.n * index.0) + index.1] } } #[cfg(test)] mod tests { use super::*; use std::panic; #[test] fn simple_matrix() { assert_eq!( Matrix { n: 2, m: 2, values: vec![int!(1), int!(2), int!(3), int!(4)], }, Matrix::new(2, 2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap() ); assert!(Matrix::new(2, 2, vec![int!(1), int!(2)]).is_none()); } #[test] fn indexes() { let mut m = Matrix::new(2, 2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap(); assert_eq!(m[(0, 0)], int!(1)); assert_eq!(m[(1, 0)], int!(3)); m[(1, 0)] = int!(5); assert_eq!(m[(1, 0)], int!(5)); let m2 = Matrix::new( 3, 2, vec![int!(1), int!(2), int!(3), int!(4), int!(5), int!(6)], ) .unwrap(); assert_eq!(m2[(0, 2)], int!(3)); assert_eq!(m2[(1, 0)], int!(4)); let result = panic::catch_unwind(|| { let _ = m2[(0, 3)]; }); assert!(result.is_err(), "Expected panic on m2[(0, 3)]"); let result2 = panic::catch_unwind(|| { let _ = m2[(2, 0)]; }); assert!(result2.is_err(), "Expected panic on m2[(2, 0)]"); } #[test] fn test_new_lattice() { let ciphertexts = vec![int!(5), int!(8), int!(12)]; let noise_bits = 2; let expected_values = vec![ int!(8), int!(8), int!(12), int!(0), int!(5), int!(0), int!(0), int!(0), int!(5), ]; let lattice = Matrix::new_lattice(noise_bits, ciphertexts).unwrap(); assert_eq!(lattice.n, 3); assert_eq!(lattice.values, expected_values); } }