diff --git a/src/agcd.rs b/src/agcd.rs index 59dde41..c882aaa 100644 --- a/src/agcd.rs +++ b/src/agcd.rs @@ -1,6 +1,6 @@ use crate::bkz::bkz_reduce; use crate::deep_lll::deep_lll; -// use crate::lll::lattice_reduce; +use crate::lll::lattice_reduce; use crate::matrix::Matrix; use crate::utils::abs; use lll_rs::l2::bigl2; @@ -8,7 +8,8 @@ use rug::{Integer, Rational}; pub fn agcd(ciphertexts: Vec, noise_bits: usize, algorithm: u8) -> Integer { // 1. Build lattice matrix basis - let basis_matrix = Matrix::new_lattice(noise_bits, ciphertexts.clone()).unwrap(); + let mut basis_matrix = Matrix::new_lattice(noise_bits, ciphertexts.clone()).unwrap(); + println!("basis: {:?}\n\n\n\n", basis_matrix); // 2. reduce with LLL, and extract first element of shortest vector let mut lll_matrix = basis_matrix.to_lll_matrix(); @@ -27,10 +28,10 @@ pub fn agcd(ciphertexts: Vec, noise_bits: usize, algorithm: u8) -> Inte let reduced = deep_lll(basis_matrix.clone(), Rational::from((51, 100))).unwrap(); reduced.columns[0][0].clone() } - // 3u8 => { - // lattice_reduce(&mut basis_matrix, 0.51, 0.75); - // lll_matrix[0][0].clone() - // } + 3u8 => { + lattice_reduce(&mut basis_matrix, 0.51, 0.75); + basis_matrix[0][0].clone() + } _ => panic!("Unknown algorithm value: {}", algorithm), }; diff --git a/src/deep_lll.rs b/src/deep_lll.rs index da8eb29..5d499cc 100644 --- a/src/deep_lll.rs +++ b/src/deep_lll.rs @@ -1,5 +1,5 @@ use crate::matrix::Matrix; -use rug::{Rational, Integer}; +use rug::{Integer, Rational}; /// Perform DeepLLL reduction on a given lattice basis represented by Matrix. /// 1/4 < delta < 1. @@ -78,7 +78,12 @@ fn gramm_schmidt(mat: &Matrix) -> (Vec>, Vec) { } /// Size-reduce column k in-place -fn size_reduce(mat: &mut Matrix, mu: &mut [Vec], b_star_sq: &mut [Rational], k: usize) { +fn size_reduce( + mat: &mut Matrix, + mu: &mut [Vec], + b_star_sq: &mut [Rational], + k: usize, +) { let mut updated = true; while updated { updated = false; diff --git a/src/lll.rs b/src/lll.rs index a77dade..3c379dc 100644 --- a/src/lll.rs +++ b/src/lll.rs @@ -1,22 +1,122 @@ +use lll_rs::vector::BigVector; +use lll_rs::matrix::Matrix as LLLMatrix; use crate::matrix::Matrix; -use lll_rs::{matrix::Matrix as LLLMatrix, vector::BigVector}; use rug::{Integer, Rational}; use std::cmp::max; -use std::ops::Sub; +/// Lattice reduction (L² algorithm, improved LLL) +pub fn lattice_reduce(basis: &mut Matrix, eta: f64, delta: f64) { + assert!(0.25 < delta && delta < 1.); + assert!(0.5 < eta && eta * eta < delta); + let d = basis.n; + let mut gram: Matrix = Matrix::init(d, d); // Gram matrix (upper triangular) + let mut r: Matrix = Matrix::init(d, d); // r_ij matrix + let mut mu: Matrix = Matrix::init(d, d); // Gram coefficient matrix + // Gram matrix + for i in 0..d { + for j in 0..=i { + gram[i][j] = basis[i].clone() * basis[j].clone(); + } + } + let eta_minus = Rational::from_f64((eta + 0.5) / 2.).unwrap(); + let delta_plus = Rational::from_f64((delta + 1.) / 2.).unwrap(); + + r[0][0] = Rational::from(&gram[0][0]); + let mut k = 1; + while k < d { + size_reduce( + k, + d, + basis, + &mut gram, + &mut mu, + &mut r, + Rational::from(&eta_minus), + ); + let delta_criterion = Rational::from(&delta_plus * &r[k - 1][k - 1]); + let scalar_criterion = &r[k][k] + Rational::from(&mu[k][k - 1]).square() * &r[k - 1][k - 1]; + // Lovazs condition + if delta_criterion < scalar_criterion { + k += 1; + } else { + basis.swap(k, k - 1); + // Updating Gram matrix + for j in 0..d { + if j < k { + gram[k][j] = basis[k].clone() * basis[j].clone(); + gram[k - 1][j] = basis[k - 1].clone() * basis[j].clone(); + } else { + gram[j][k] = basis[k].clone() * basis[j].clone(); + gram[j][k - 1] = basis[k - 1].clone() * basis[j].clone(); + } + } + // Updating mu and r + for i in 0..=k { + for j in 0..=i { + r[i][j] = Rational::from(&gram[i][j]) + - (0..j) + .map(|index| Rational::from(&mu[j][index] * &r[i][index])) + .sum::(); + mu[i][j] = Rational::from(&r[i][j] / &r[j][j]); + } + } + k = max(1, k - 1); + } + } +} + + +fn size_reduce( + k: usize, + d: usize, + basis: &mut Matrix, + gram: &mut Matrix, + mu: &mut Matrix, + r: &mut Matrix, + eta: Rational, +) { + // Update mu and r + for i in 0..=k { + r[k][i] = Rational::from(&gram[k][i]) + - (0..i) + .map(|index| Rational::from(&mu[i][index] * &r[k][index])) + .sum::(); + mu[k][i] = Rational::from(&r[k][i] / &r[i][i]); + } + + if (0..k).any(|index| mu[k][index] > eta) { + for i in (0..k).rev() { + let (_, x) = Rational::from(&mu[k][i]).fract_round(Integer::new()); + basis[k] = basis[k].clone() - basis[i].mul_scalar(&x); + // Updating Gram matrix + for j in 0..d { + if j < k { + gram[k][j] = basis[k].clone() * basis[j].clone(); + } else { + gram[j][k] = basis[k].clone() * basis[j].clone(); + } + } + for j in 0..i { + let shift = Rational::from(&mu[i][j]); + mu[k][j] -= Rational::from(&x) * shift; + } + } + size_reduce(k, d, basis, gram, mu, r, eta); + } +} + + +/// conversion to LLLMatrix impl Matrix { pub fn to_lll_matrix(&self) -> LLLMatrix { let n = self.n; let mut lll_mat = LLLMatrix::init(n, n); - for row_idx in 0..n { let mut elements = Vec::with_capacity(n); - for col_idx in 0..n { - let val = self[(row_idx, col_idx)].clone(); + let val = self[(col_idx, row_idx)].clone(); elements.push(val); } - lll_mat[row_idx] = BigVector::from_vector(elements); } lll_mat diff --git a/src/matrix.rs b/src/matrix.rs index 67e14a3..a164906 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,9 +1,6 @@ use crate::vector::Vector; use rug::{ops::Pow, Integer}; -use std::{ - fmt, - ops::{Index, IndexMut}, -}; +use std::ops::{Index, IndexMut}; #[derive(Debug, PartialEq, Clone)] pub struct Matrix { @@ -12,12 +9,19 @@ pub struct Matrix { pub columns: Vec>, } -impl Matrix { +impl Matrix { + pub fn init(n: usize, m: usize) -> Self { + let mut columns = Vec::with_capacity(n); + for _ in 0..n { + columns.push(Vector::init(m)); + } + Self { n, m, columns } + } + pub fn new(n: usize, m: usize, values: Vec) -> Option { if n * m != values.len() { return None; } - // avoid requiring Vec: Clone by building manually let mut columns = Vec::with_capacity(n); for _ in 0..n { columns.push(Vec::with_capacity(m)); @@ -32,6 +36,10 @@ impl Matrix { columns: columns.into_iter().map(Vector::from_vec).collect(), }) } + + pub fn swap(&mut self, i: usize, j: usize) { + self.columns.swap(i, j); + } } impl Matrix @@ -40,19 +48,22 @@ where { pub fn new_lattice(noise_bits: usize, ciphertexts: Vec) -> Option { let n = ciphertexts.len(); - let mut columns = vec![Vec::with_capacity(n); n]; + if n < 1 { + return None; + } + let mut columns = vec![vec![]; n]; - // First row: [2^(noise+1), ciphertexts[1], ciphertexts[2], ...] + // First column: [2^(noise_bits+1), ciphertexts[1], ciphertexts[2], ..., ciphertexts[n-1]] let two_pow = Integer::from(2u64).pow((noise_bits + 1) as u32); columns[0].push(T::from(two_pow)); for i in 1..n { - columns[i].push(ciphertexts[i].clone()); + columns[0].push(ciphertexts[i].clone()); } - // Subsequent rows form identity matrix with ciphertexts[0] - for i in 1..n { - for (j, column) in columns.iter_mut().enumerate().take(n) { - column.push(if i == j { + // Subsequent columns: identity matrix with ciphertexts[0] + for j in 1..n { + for i in 0..n { + columns[j].push(if i == j { ciphertexts[0].clone() } else { T::default() @@ -89,6 +100,27 @@ impl IndexMut<(usize, usize)> for Matrix { } } +impl Index for Matrix { + type Output = Vector; + /// `matrix[col]` yields a `&Vector` representing that column. + fn index(&self, col: usize) -> &Self::Output { + if col >= self.columns.len() { + panic!("Matrix column index out of bounds"); + } + &self.columns[col] + } +} + +impl IndexMut for Matrix { + /// `matrix[col] = …` or `matrix[col][row] = …` to mutate a column (or element). + fn index_mut(&mut self, col: usize) -> &mut Self::Output { + if col >= self.columns.len() { + panic!("Matrix column index out of bounds"); + } + &mut self.columns[col] + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/vector.rs b/src/vector.rs index c1b6ce7..a08e289 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -4,7 +4,7 @@ use std::{ ops::{Add, Index, IndexMut, Mul, Sub}, }; -#[derive(Clone, PartialEq)] +#[derive(PartialEq)] pub struct Vector { pub elements: Vec, } @@ -44,6 +44,14 @@ where } } +impl Clone for Vector { + fn clone(&self) -> Self { + Vector { + elements: self.elements.clone(), + } + } +} + impl Add for Vector where T: Add,