diff --git a/src/agcd.rs b/src/agcd.rs index 6a67977..59dde41 100644 --- a/src/agcd.rs +++ b/src/agcd.rs @@ -1,5 +1,6 @@ use crate::bkz::bkz_reduce; use crate::deep_lll::deep_lll; +// use crate::lll::lattice_reduce; use crate::matrix::Matrix; use crate::utils::abs; use lll_rs::l2::bigl2; @@ -26,6 +27,10 @@ pub fn agcd(ciphertexts: Vec, noise_bits: usize, algorithm: u8) -> Inte let reduced = deep_lll(basis_matrix.clone(), Rational::from((51, 100))).unwrap(); reduced.columns[0][0].clone() } + // 3u8 => { + // lattice_reduce(&mut basis_matrix, 0.51, 0.75); + // lll_matrix[0][0].clone() + // } _ => panic!("Unknown algorithm value: {}", algorithm), }; diff --git a/src/deep_lll.rs b/src/deep_lll.rs index e85a14b..da8eb29 100644 --- a/src/deep_lll.rs +++ b/src/deep_lll.rs @@ -1,9 +1,9 @@ use crate::matrix::Matrix; -use rug::Rational; +use rug::{Rational, Integer}; /// Perform DeepLLL reduction on a given lattice basis represented by Matrix. /// 1/4 < delta < 1. -pub fn deep_lll(mut mat: Matrix, delta: Rational) -> Option { +pub fn deep_lll(mut mat: Matrix, delta: Rational) -> Option> { let n = mat.n; let (mut mu, mut b_star_sq) = gramm_schmidt(&mat); let mut k = 2; @@ -12,7 +12,10 @@ pub fn deep_lll(mut mat: Matrix, delta: Rational) -> Option { while k <= n { if iterations >= MAX_ITERATIONS { - eprintln!("Warning: DeepLLL did not converge after {} iterations", MAX_ITERATIONS); + eprintln!( + "Warning: DeepLLL did not converge after {} iterations", + MAX_ITERATIONS + ); return Some(mat); } iterations += 1; @@ -40,7 +43,7 @@ pub fn deep_lll(mut mat: Matrix, delta: Rational) -> Option { } /// Compute Gram–Schmidt coefficients and squared norms of orthogonal vectors b*_i as Rationals. -fn gramm_schmidt(mat: &Matrix) -> (Vec>, Vec) { +fn gramm_schmidt(mat: &Matrix) -> (Vec>, Vec) { let n = mat.n; let m = mat.m; let mut mu = vec![vec![Rational::from((0, 1)); n]; n]; @@ -75,7 +78,7 @@ fn gramm_schmidt(mat: &Matrix) -> (Vec>, Vec) { } /// Size-reduce column k in-place -fn size_reduce(mat: &mut Matrix, mu: &mut [Vec], b_star_sq: &mut [Rational], k: usize) { +fn size_reduce(mat: &mut Matrix, mu: &mut [Vec], b_star_sq: &mut [Rational], k: usize) { let mut updated = true; while updated { updated = false; @@ -100,13 +103,13 @@ fn size_reduce(mat: &mut Matrix, mu: &mut [Vec], b_star_sq: &mut [Rati } /// Deep insertion: move column k into position i (1-based), shifting intermediate columns right. -fn deep_insert(mat: &mut Matrix, i: usize, k: usize) { +fn deep_insert(mat: &mut Matrix, i: usize, k: usize) { let col = mat.columns.remove(k - 1); mat.columns.insert(i - 1, col); } /// Compute squared Euclidean norm of column k as a Rational. -fn norm_sq(mat: &Matrix, k: usize) -> Rational { +fn norm_sq(mat: &Matrix, k: usize) -> Rational { let mut sum = Rational::from((0, 1)); for row in 0..mat.n { let v = mat[(k - 1, row)].clone(); diff --git a/src/lll.rs b/src/lll.rs index 3ee5513..a77dade 100644 --- a/src/lll.rs +++ b/src/lll.rs @@ -1,8 +1,10 @@ -use lll_rs::{matrix::Matrix as LLLMatrix, vector::BigVector}; - use crate::matrix::Matrix; +use lll_rs::{matrix::Matrix as LLLMatrix, vector::BigVector}; +use rug::{Integer, Rational}; +use std::cmp::max; +use std::ops::Sub; -impl Matrix { +impl Matrix { pub fn to_lll_matrix(&self) -> LLLMatrix { let n = self.n; let mut lll_mat = LLLMatrix::init(n, n); diff --git a/src/matrix.rs b/src/matrix.rs index f1183c2..67e14a3 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,27 +1,27 @@ -use crate::vector::IntVector; -use rug::ops::Pow; -use rug::Integer; -use std::ops::{Index, IndexMut}; - -macro_rules! int { - ($x:expr) => { - rug::Integer::from($x) - }; -} +use crate::vector::Vector; +use rug::{ops::Pow, Integer}; +use std::{ + fmt, + ops::{Index, IndexMut}, +}; #[derive(Debug, PartialEq, Clone)] -pub struct Matrix { +pub struct Matrix { pub n: usize, // number of columns pub m: usize, // number of rows - pub columns: Vec, + pub columns: Vec>, } -impl Matrix { - pub fn new(n: usize, m: usize, values: Vec) -> Option { +impl Matrix { + pub fn new(n: usize, m: usize, values: Vec) -> Option { if n * m != values.len() { return None; } - let mut columns = vec![Vec::with_capacity(m); n]; + // avoid requiring Vec: Clone by building manually + let mut columns = Vec::with_capacity(n); + for _ in 0..n { + columns.push(Vec::with_capacity(m)); + } for (i, value) in values.into_iter().enumerate() { let col = i % n; columns[col].push(value); @@ -29,39 +29,47 @@ impl Matrix { Some(Matrix { n, m, - columns: columns.into_iter().map(IntVector::from_vec).collect(), + columns: columns.into_iter().map(Vector::from_vec).collect(), }) } +} - pub fn new_lattice(noise_bits: usize, ciphertexts: Vec) -> Option { +impl Matrix +where + T: Clone + Default + From, +{ + pub fn new_lattice(noise_bits: usize, ciphertexts: Vec) -> Option { let n = ciphertexts.len(); let mut columns = vec![Vec::with_capacity(n); n]; - columns[0].push(int!(2u64).pow((noise_bits + 1) as u32)); + // First row: [2^(noise+1), ciphertexts[1], ciphertexts[2], ...] + let two_pow = Integer::from(2u64).pow((noise_bits + 1) as u32); + columns[0].push(T::from(two_pow)); for i in 1..n { columns[i].push(ciphertexts[i].clone()); } + // Subsequent rows form identity matrix with ciphertexts[0] for i in 1..n { for (j, column) in columns.iter_mut().enumerate().take(n) { - if i == j { - column.push(ciphertexts[0].clone()); + column.push(if i == j { + ciphertexts[0].clone() } else { - column.push(int!(0)); - } + T::default() + }); } } Some(Matrix { n, m: n, - columns: columns.into_iter().map(IntVector::from_vec).collect(), + columns: columns.into_iter().map(Vector::from_vec).collect(), }) } } -impl Index<(usize, usize)> for Matrix { - type Output = Integer; +impl Index<(usize, usize)> for Matrix { + type Output = T; fn index(&self, index: (usize, usize)) -> &Self::Output { let (col, row) = index; if row >= self.n || col >= self.m { @@ -71,7 +79,7 @@ impl Index<(usize, usize)> for Matrix { } } -impl IndexMut<(usize, usize)> for Matrix { +impl IndexMut<(usize, usize)> for Matrix { fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output { let (col, row) = index; if row >= self.n || col >= self.m { @@ -84,16 +92,23 @@ impl IndexMut<(usize, usize)> for Matrix { #[cfg(test)] mod tests { use super::*; - use std::panic; + use rug::Rational; + + macro_rules! int { + ($x:expr) => { + Integer::from($x) + }; + } + + macro_rules! rational { + ($x:expr) => { + Rational::from(Integer::from($x)) + }; + } #[test] - fn simple_dimensions_and_index() { + fn test_integer_matrix() { let m = Matrix::new(2, 2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap(); - assert_eq!(m.n, 2); - assert_eq!(m.m, 2); - - // values: [1,2,3,4] - // columns: [[1,3], [2,4]] assert_eq!(m[(0, 0)], int!(1)); assert_eq!(m[(0, 1)], int!(2)); assert_eq!(m[(1, 0)], int!(3)); @@ -101,47 +116,26 @@ mod tests { } #[test] - fn indexes_and_mutation() { - let mut m = Matrix::new(2, 2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap(); - assert_eq!(m[(1, 0)], int!(3)); - - m[(1, 0)] = int!(5); - assert_eq!(m[(1, 0)], int!(5)); - - let m2 = Matrix::new( - 3, + fn test_rational_matrix() { + let m = Matrix::new( 2, - vec![int!(1), int!(2), int!(3), int!(4), int!(5), int!(6)], + 2, + vec![rational!(1), rational!(2), rational!(3), rational!(4)], ) .unwrap(); - - assert_eq!(m2[(0, 2)], int!(3)); - assert_eq!(m2[(1, 0)], int!(4)); - - let result = panic::catch_unwind(|| { - let _ = m2[(2, 0)]; - }); - assert!(result.is_err(), "Expected panic on m2[(2, 0)]"); - - let result2 = panic::catch_unwind(|| { - let _ = m2[(0, 3)]; - }); - assert!(result2.is_err(), "Expected panic on m2[(0, 3)]"); + assert_eq!(m[(0, 0)], rational!(1)); + assert_eq!(m[(0, 1)], rational!(2)); + assert_eq!(m[(1, 0)], rational!(3)); + assert_eq!(m[(1, 1)], rational!(4)); } #[test] - fn test_new_lattice_layout() { + fn test_lattice_matrix() { let ciphertexts = vec![int!(5), int!(8), int!(12)]; let noise_bits = 2; - let lattice = Matrix::new_lattice(noise_bits, ciphertexts.clone()).unwrap(); + let lattice = Matrix::new_lattice(noise_bits, ciphertexts).unwrap(); - assert_eq!(lattice.n, 3); - assert_eq!(lattice.m, 3); - - // 1st column = [2^(noise+1), 0, 0] = [8,0,0] - // 2nd column = [ciphertexts[1], ciphertexts[0], 0] = [8,5,0] - // 3rd column = [ciphertexts[2], 0, ciphertexts[0]] = [12,0,5] - let expected_flat = vec![ + let expected = vec![ int!(8), int!(8), int!(12), @@ -152,14 +146,24 @@ mod tests { int!(0), int!(5), ]; - - let mut actual_flat = Vec::with_capacity(9); - for col in 0..lattice.m { - for row in 0..lattice.n { - actual_flat.push(lattice[(col, row)].clone()); + let mut actual = Vec::new(); + for col in 0..3 { + for row in 0..3 { + actual.push(lattice[(col, row)].clone()); } } + assert_eq!(actual, expected); + } - assert_eq!(actual_flat, expected_flat); + #[test] + fn test_rational_lattice() { + let ciphertexts = vec![rational!(5), rational!(8), rational!(12)]; + let noise_bits = 2; + let lattice = Matrix::new_lattice(noise_bits, ciphertexts).unwrap(); + + let two_pow = Rational::from(Integer::from(2u64).pow(3)); + assert_eq!(lattice[(0, 0)], two_pow); + assert_eq!(lattice[(0, 1)], rational!(8)); + assert_eq!(lattice[(0, 2)], rational!(12)); } } diff --git a/src/vector.rs b/src/vector.rs index 503a8b2..c1b6ce7 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -1,42 +1,53 @@ -use rug::Integer; use std::{ fmt, + iter::Sum, ops::{Add, Index, IndexMut, Mul, Sub}, }; -macro_rules! int { - ($x:expr) => { - rug::Integer::from($x) - }; -} - #[derive(Clone, PartialEq)] -pub struct IntVector { - elements: Vec, +pub struct Vector { + pub elements: Vec, } -impl IntVector { +impl Vector { pub fn init(size: usize) -> Self { - Self { - elements: vec![Default::default(); size], + let mut elements = Vec::with_capacity(size); + for _ in 0..size { + elements.push(T::default()); } + Self { elements } } +} - pub fn from_vec(elements: Vec) -> Self { +impl Vector { + pub fn from_vec(elements: Vec) -> Self { Self { elements } } pub fn size(&self) -> usize { self.elements.len() } +} - pub fn mul_scalar(&self, other: &Integer) -> Self { - let n = self.size(); - Self::from_vec((0..n).map(|i| int!(&self.elements[i] * other)).collect()) +impl Vector +where + T: Mul + Clone, +{ + pub fn mul_scalar(&self, scalar: &T) -> Self { + let elements = self + .elements + .iter() + .cloned() + .map(|e| e * scalar.clone()) + .collect(); + Self { elements } } } -impl Add for IntVector { +impl Add for Vector +where + T: Add, +{ type Output = Self; fn add(self, other: Self) -> Self::Output { assert_eq!(self.size(), other.size()); @@ -46,11 +57,14 @@ impl Add for IntVector { .zip(other.elements) .map(|(a, b)| a + b) .collect(); - IntVector::from_vec(elements) + Vector::from_vec(elements) } } -impl Sub for IntVector { +impl Sub for Vector +where + T: Sub, +{ type Output = Self; fn sub(self, other: Self) -> Self::Output { assert_eq!(self.size(), other.size()); @@ -60,36 +74,40 @@ impl Sub for IntVector { .zip(other.elements) .map(|(a, b)| a - b) .collect(); - IntVector::from_vec(elements) + Vector::from_vec(elements) } } -impl Mul for IntVector { - type Output = Integer; - fn mul(self, other: Self) -> Self::Output { - let n = self.size(); - assert_eq!(n, other.size()); - (0..n) - .map(|i| Integer::from(&self.elements[i] * &other.elements[i])) +impl Mul for Vector +where + T: Mul + Sum, +{ + type Output = T; + fn mul(self, other: Self) -> T { + assert_eq!(self.size(), other.size()); + self.elements + .into_iter() + .zip(other.elements) + .map(|(a, b)| a * b) .sum() } } -impl Index for IntVector { - type Output = Integer; +impl Index for Vector { + type Output = T; - fn index(&self, index: usize) -> &Integer { + fn index(&self, index: usize) -> &T { &self.elements[index] } } -impl IndexMut for IntVector { - fn index_mut(&mut self, index: usize) -> &mut Integer { +impl IndexMut for Vector { + fn index_mut(&mut self, index: usize) -> &mut T { &mut self.elements[index] } } -impl fmt::Debug for IntVector { +impl fmt::Debug for Vector { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{:?}", self.elements) } @@ -98,10 +116,23 @@ impl fmt::Debug for IntVector { #[cfg(test)] mod tests { use super::*; + use rug::{Integer, Rational}; + + macro_rules! int { + ($x:expr) => { + Integer::from($x) + }; + } + + macro_rules! rational { + ($x:expr) => { + Rational::from(Integer::from($x)) + }; + } #[test] - fn test_from_vec() { - let v = IntVector::from_vec(vec![int!(1), int!(2), int!(3)]); + fn test_integer_vector_from_vec() { + let v = Vector::from_vec(vec![int!(1), int!(2), int!(3)]); assert_eq!(v.size(), 3); assert_eq!(v[0], int!(1)); assert_eq!(v[1], int!(2)); @@ -109,9 +140,18 @@ mod tests { } #[test] - fn test_add_vectors() { - let v1 = IntVector::from_vec(vec![int!(1), int!(2), int!(3)]); - let v2 = IntVector::from_vec(vec![int!(4), int!(5), int!(6)]); + fn test_rational_vector_from_vec() { + let v = Vector::from_vec(vec![rational!(1), rational!(2), rational!(3)]); + assert_eq!(v.size(), 3); + assert_eq!(v[0], rational!(1)); + assert_eq!(v[1], rational!(2)); + assert_eq!(v[2], rational!(3)); + } + + #[test] + fn test_add_integer_vectors() { + let v1 = Vector::from_vec(vec![int!(1), int!(2), int!(3)]); + let v2 = Vector::from_vec(vec![int!(4), int!(5), int!(6)]); let result = v1 + v2; assert_eq!(result[0], int!(5)); assert_eq!(result[1], int!(7)); @@ -119,9 +159,9 @@ mod tests { } #[test] - fn test_sub_vectors() { - let v1 = IntVector::from_vec(vec![int!(5), int!(7), int!(9)]); - let v2 = IntVector::from_vec(vec![int!(4), int!(5), int!(6)]); + fn test_sub_integer_vectors() { + let v1 = Vector::from_vec(vec![int!(5), int!(7), int!(9)]); + let v2 = Vector::from_vec(vec![int!(4), int!(5), int!(6)]); let result = v1 - v2; assert_eq!(result[0], int!(1)); assert_eq!(result[1], int!(2)); @@ -129,8 +169,8 @@ mod tests { } #[test] - fn test_scalar_multiplication() { - let v = IntVector::from_vec(vec![int!(2), int!(3), int!(4)]); + fn test_scalar_multiplication_integer() { + let v = Vector::from_vec(vec![int!(2), int!(3), int!(4)]); let scalar = int!(5); let result = v.mul_scalar(&scalar); assert_eq!(result[0], int!(10)); @@ -139,17 +179,37 @@ mod tests { } #[test] - fn test_dot_product() { - let v1 = IntVector::from_vec(vec![int!(1), int!(2), int!(3)]); - let v2 = IntVector::from_vec(vec![int!(4), int!(5), int!(6)]); + fn test_dot_product_integer() { + let v1 = Vector::from_vec(vec![int!(1), int!(2), int!(3)]); + let v2 = Vector::from_vec(vec![int!(4), int!(5), int!(6)]); let dot = v1 * v2; - assert_eq!(dot, int!(32)); // 1*4 + 2*5 + 3*6 = 32 + assert_eq!(dot, int!(32)); } #[test] - fn test_indexing_mut() { - let mut v = IntVector::from_vec(vec![int!(1), int!(2), int!(3)]); + fn test_indexing_mut_integer() { + let mut v = Vector::from_vec(vec![int!(1), int!(2), int!(3)]); v[1] += int!(10); assert_eq!(v[1], int!(12)); } + + #[test] + fn test_add_rational_vectors() { + let v1 = Vector::from_vec(vec![rational!(1), rational!(2), rational!(3)]); + let v2 = Vector::from_vec(vec![rational!(4), rational!(5), rational!(6)]); + let result = v1 + v2; + assert_eq!(result[0], rational!(5)); + assert_eq!(result[1], rational!(7)); + assert_eq!(result[2], rational!(9)); + } + + #[test] + fn test_scalar_multiplication_rational() { + let v = Vector::from_vec(vec![rational!(2), rational!(3), rational!(4)]); + let scalar = rational!(5); + let result = v.mul_scalar(&scalar); + assert_eq!(result[0], rational!(10)); + assert_eq!(result[1], rational!(15)); + assert_eq!(result[2], rational!(20)); + } }