diff --git a/scripts/gen_values.py b/scripts/gen_values.py index af0e32a..9b991f7 100644 --- a/scripts/gen_values.py +++ b/scripts/gen_values.py @@ -3,6 +3,7 @@ import random import argparse def generate_test_values(noise_bits, number, p_bits): + sys.set_int_max_str_digits(10000000) p = random.randint(2**(p_bits-1), 2**p_bits) while p % 2 == 0: @@ -10,7 +11,7 @@ def generate_test_values(noise_bits, number, p_bits): max_noise = (1 << noise_bits) - 1 # 2^noise_bits - 1 - a = [str(p * random.randint(1, 100) + random.randint(0, max_noise)) for _ in range(number)] + a = [str(p * random.randint(1, 2) + random.randint(0, max_noise)) for _ in range(number)] return noise_bits, a, p diff --git a/scripts/script.py b/scripts/script.py index 952992f..ad6300b 100644 --- a/scripts/script.py +++ b/scripts/script.py @@ -1,5 +1,9 @@ #!/bin/python3 -import argparse, subprocess, re, tempfile, os +import argparse +import subprocess +import re +import tempfile +import os import numpy as np import matplotlib.pyplot as plt from gen_values import generate_test_file @@ -21,24 +25,22 @@ def run_agcd(input_file): print(f"Error parsing output for input_file={input_file}: {e}") return None - -def plot_curves(noise_bits, p_bits, test_numbers, successes): +def plot_curves(noise_bits, p_bits, test_numbers, success_rates): plt.figure(figsize=(10, 6)) - plt.plot(test_numbers, successes, marker='o') + plt.plot(test_numbers, success_rates, marker='o') plt.xlabel('Number of Test Values') - plt.ylabel('Success (1 = Correct, 0 = Incorrect)') - plt.title(f'Success vs. Number of Test Values\n(noise_bits={noise_bits}, p_bits={p_bits})') + plt.ylabel('Success Rate') + plt.title(f'Success Rate vs. Number of Test Values\n(noise_bits={noise_bits}, p_bits={p_bits})') plt.grid(True) plt.ylim(-0.1, 1.1) - plt.yticks([0, 1]) plt.xticks(test_numbers) - plt.savefig('success_plot.png') + plt.savefig('success_rate_plot.png') plt.show() def main(): parser = argparse.ArgumentParser(description='Test AGCD with varying number of test values.') - parser.add_argument('--noise_bits', type=int, default=8, help='Number of noise bits') - parser.add_argument('--p_bits', type=int, default=128, help='Number of bits for p') + parser.add_argument('--noise_bits', type=int, default=0, help='Number of noise bits') + parser.add_argument('--p_bits', type=int, default=10000, help='Number of bits for p') parser.add_argument('--min_values', type=int, default=2, help='Minimum number of test values') parser.add_argument('--max_values', type=int, default=100, help='Maximum number of test values') args = parser.parse_args() @@ -46,27 +48,32 @@ def main(): noise_bits = args.noise_bits p_bits = args.p_bits test_numbers = range(args.min_values, args.max_values + 1) - successes = [] + success_rates = [] + num_trials = 100 for num_values in test_numbers: - # Create temporary test file - with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as tmp_file: - true_p = generate_test_file(noise_bits, num_values, p_bits, tmp_file.name) + successes = 0 + for _ in range(num_trials): + # Create temporary test file + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as tmp_file: + true_p = generate_test_file(noise_bits, num_values, p_bits, tmp_file.name) - # Run AGCD - recovered_p = run_agcd(tmp_file.name) + # Run AGCD + recovered_p = run_agcd(tmp_file.name) - # Check if recovery was successful - success = 1 if recovered_p != None and abs(recovered_p - true_p) <= 4 else 0 - successes.append(success) + # Check if recovery was successful + if recovered_p is not None and abs(recovered_p-true_p) <= 2000: + successes += 1 - # Clean up - os.unlink(tmp_file.name) + # Clean up + os.unlink(tmp_file.name) - print(f"Number of values: {num_values}, Success: {'Yes' if success else 'No'}") + success_rate = successes / num_trials + success_rates.append(success_rate) + print(f"Number of values: {num_values}, Success rate: {success_rate:.3f} ({successes}/{num_trials})") # Plot the results - plot_curves(noise_bits, p_bits, test_numbers, successes) + plot_curves(noise_bits, p_bits, test_numbers, success_rates) if __name__ == "__main__": main() diff --git a/src/agcd.rs b/src/agcd.rs index 522ffd7..f655507 100644 --- a/src/agcd.rs +++ b/src/agcd.rs @@ -1,16 +1,21 @@ +use crate::bkz::bkz_reduce; use crate::matrix::Matrix; use crate::utils::abs; use lll_rs::l2::bigl2; use rug::Integer; -pub fn agcd(ciphertexts: Vec, noise_bits: usize) -> Integer { +pub fn agcd(ciphertexts: Vec, noise_bits: usize, algorithm: u8) -> Integer { // 1. Build lattice matrix basis let basis_matrix = Matrix::new_lattice(noise_bits, ciphertexts.clone()).unwrap(); // 2. reduce with LLL let mut lll_matrix = basis_matrix.to_lll_matrix(); println!("basis: {:?}", lll_matrix); - bigl2::lattice_reduce(&mut lll_matrix, 0.51, 0.75); + match algorithm { + 0u8 => bigl2::lattice_reduce(&mut lll_matrix, 0.51, 0.75), + 1u8 => bkz_reduce(&mut lll_matrix, 16, 0.75, 0.75, 10), + _ => panic!(), + } println!("basis after reduction: {:?}", lll_matrix); // 3. Extract shortest vector @@ -36,7 +41,10 @@ pub fn agcd(ciphertexts: Vec, noise_bits: usize) -> Integer { let p_guess = abs((x0 - r0) / q0); println!("Recovered p: {}", p_guess); - println!("Approximate GCD with noise_bits={}: {}", noise_bits, p_guess); + println!( + "Approximate GCD with noise_bits={}: {}", + noise_bits, p_guess + ); p_guess } diff --git a/src/bkz.rs b/src/bkz.rs new file mode 100644 index 0000000..3910a04 --- /dev/null +++ b/src/bkz.rs @@ -0,0 +1,122 @@ +use lll_rs::vector::Vector; +use lll_rs::{l2::bigl2, matrix::Matrix as LLLMatrix, vector::BigVector}; +use rug::Integer; +use std::cmp::min; + +/// Gram-Schmidt orthogonalization +fn compute_gram_schmidt( + basis: &LLLMatrix, +) -> (Vec, Vec>, Vec) { + let n = basis.dimensions().1; + let mut orthogonal = Vec::with_capacity(n); + let mut mu = vec![vec![Integer::new(); n]; n]; + let mut b = Vec::with_capacity(n); + + for i in 0..n { + let mut gs = basis[i].clone(); + for j in 0..i { + let numerator = inner_product(&gs, &orthogonal[j]); + let denominator: &Integer = &b[j]; + mu[i][j] = numerator.clone() / denominator.clone(); + + gs = gs.sub(&orthogonal[j].clone().mulf(&mu[i][j].clone())); + } + let norm = inner_product(&gs, &gs); + b.push(norm); + orthogonal.push(gs); + } + + (orthogonal, mu, b) +} + +/// BigVectors product +fn inner_product(a: &BigVector, b: &BigVector) -> Integer { + assert!(a.dimension() == b.dimension()); + let mut sum = Integer::new(); + for i in 0..a.dimension() { + sum += &a[i] * &b[i]; + } + sum +} + +/// vectors orthogonal projection +fn project_block( + basis: &LLLMatrix, + start: usize, + end: usize, + orthogonal: &[BigVector], + mu: &[Vec], +) -> LLLMatrix { + let mut projected = LLLMatrix::init(end - start, basis.dimensions().0); + + for (idx, i) in (start..end).enumerate() { + let mut v = basis[i].clone(); + for (j, ortho_vec) in orthogonal.iter().enumerate().take(start) { + let coeff = &mu[i][j]; + v = v.sub(&ortho_vec.clone().mulf(&coeff.clone())); + } + + projected[idx] = v; + } + + projected +} + +/// BKZ reduction using LLL as SVP‐oracle +pub fn bkz_reduce( + basis: &mut LLLMatrix, + block_size: usize, + delta: f64, + eta: f64, + max_rounds: usize, +) { + let n = basis.dimensions().1; + + for _ in 0..max_rounds { + let (orthogonal, mu, _b) = compute_gram_schmidt(basis); + + for k in 0..n { + let l = min(k + block_size, n); + if l - k < 2 { + continue; + } + + let mut projected = project_block(basis, k, l, &orthogonal, &mu); + bigl2::lattice_reduce(&mut projected, delta, eta); + + for (i_block, i) in (k..l).enumerate() { + basis[i] = projected[i_block].clone(); + } + } + } + bigl2::lattice_reduce(basis, delta, eta); +} + + +#[cfg(test)] +mod tests { + use super::*; + use crate::matrix::Matrix; + use rug::Integer; + + #[test] + fn test_bkz_reduction() { + let ciphertexts = vec![Integer::from(5), Integer::from(8), Integer::from(12)]; + let noise_bits = 2; + + let basis_matrix = Matrix::new_lattice(noise_bits, ciphertexts).unwrap(); + let mut lll_matrix = basis_matrix.to_lll_matrix(); + + bkz_reduce(&mut lll_matrix, 5, 0.51, 0.75, 3); + + let shortest_vector = &lll_matrix[0]; + let mut norm = Integer::from(0); + for i in 0..=lll_matrix.dimensions().0 - 1 { + norm += &shortest_vector[i] * &shortest_vector[i]; + } + assert!( + norm > Integer::from(0), + "Shortest vector should not be zero" + ); + } +} diff --git a/src/main.rs b/src/main.rs index ebf1311..a5ce515 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ mod agcd; +mod bkz; mod file; mod lll; mod matrix; @@ -23,6 +24,9 @@ enum Commands { Agcd { /// (default './input.txt') path: Option, + /// Algorithm variant to use (default: 0) + #[arg(short = 'a', long, default_value_t = 0u8)] + algorithm: u8, }, } @@ -30,11 +34,11 @@ fn main() -> std::io::Result<()> { let cli = Cli::parse(); match &cli.command { - Commands::Agcd { path } => { + Commands::Agcd { path, algorithm } => { let path = path.as_deref().unwrap_or(Path::new("./input.txt")); let input = parse_file(path)?; - agcd(input.numbers, input.noise_bits); + agcd(input.numbers, input.noise_bits, *algorithm); } } Ok(())