bkz using LLL oracle

This commit is contained in:
Sam Hadow 2025-05-23 19:25:54 +02:00
parent 34ab20c982
commit 938892acd6
5 changed files with 171 additions and 29 deletions

View File

@ -3,6 +3,7 @@ import random
import argparse import argparse
def generate_test_values(noise_bits, number, p_bits): def generate_test_values(noise_bits, number, p_bits):
sys.set_int_max_str_digits(10000000)
p = random.randint(2**(p_bits-1), 2**p_bits) p = random.randint(2**(p_bits-1), 2**p_bits)
while p % 2 == 0: while p % 2 == 0:
@ -10,7 +11,7 @@ def generate_test_values(noise_bits, number, p_bits):
max_noise = (1 << noise_bits) - 1 # 2^noise_bits - 1 max_noise = (1 << noise_bits) - 1 # 2^noise_bits - 1
a = [str(p * random.randint(1, 100) + random.randint(0, max_noise)) for _ in range(number)] a = [str(p * random.randint(1, 2) + random.randint(0, max_noise)) for _ in range(number)]
return noise_bits, a, p return noise_bits, a, p

View File

@ -1,5 +1,9 @@
#!/bin/python3 #!/bin/python3
import argparse, subprocess, re, tempfile, os import argparse
import subprocess
import re
import tempfile
import os
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from gen_values import generate_test_file from gen_values import generate_test_file
@ -21,24 +25,22 @@ def run_agcd(input_file):
print(f"Error parsing output for input_file={input_file}: {e}") print(f"Error parsing output for input_file={input_file}: {e}")
return None return None
def plot_curves(noise_bits, p_bits, test_numbers, success_rates):
def plot_curves(noise_bits, p_bits, test_numbers, successes):
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))
plt.plot(test_numbers, successes, marker='o') plt.plot(test_numbers, success_rates, marker='o')
plt.xlabel('Number of Test Values') plt.xlabel('Number of Test Values')
plt.ylabel('Success (1 = Correct, 0 = Incorrect)') plt.ylabel('Success Rate')
plt.title(f'Success vs. Number of Test Values\n(noise_bits={noise_bits}, p_bits={p_bits})') plt.title(f'Success Rate vs. Number of Test Values\n(noise_bits={noise_bits}, p_bits={p_bits})')
plt.grid(True) plt.grid(True)
plt.ylim(-0.1, 1.1) plt.ylim(-0.1, 1.1)
plt.yticks([0, 1])
plt.xticks(test_numbers) plt.xticks(test_numbers)
plt.savefig('success_plot.png') plt.savefig('success_rate_plot.png')
plt.show() plt.show()
def main(): def main():
parser = argparse.ArgumentParser(description='Test AGCD with varying number of test values.') parser = argparse.ArgumentParser(description='Test AGCD with varying number of test values.')
parser.add_argument('--noise_bits', type=int, default=8, help='Number of noise bits') parser.add_argument('--noise_bits', type=int, default=0, help='Number of noise bits')
parser.add_argument('--p_bits', type=int, default=128, help='Number of bits for p') parser.add_argument('--p_bits', type=int, default=10000, help='Number of bits for p')
parser.add_argument('--min_values', type=int, default=2, help='Minimum number of test values') parser.add_argument('--min_values', type=int, default=2, help='Minimum number of test values')
parser.add_argument('--max_values', type=int, default=100, help='Maximum number of test values') parser.add_argument('--max_values', type=int, default=100, help='Maximum number of test values')
args = parser.parse_args() args = parser.parse_args()
@ -46,27 +48,32 @@ def main():
noise_bits = args.noise_bits noise_bits = args.noise_bits
p_bits = args.p_bits p_bits = args.p_bits
test_numbers = range(args.min_values, args.max_values + 1) test_numbers = range(args.min_values, args.max_values + 1)
successes = [] success_rates = []
num_trials = 100
for num_values in test_numbers: for num_values in test_numbers:
# Create temporary test file successes = 0
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as tmp_file: for _ in range(num_trials):
true_p = generate_test_file(noise_bits, num_values, p_bits, tmp_file.name) # Create temporary test file
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as tmp_file:
true_p = generate_test_file(noise_bits, num_values, p_bits, tmp_file.name)
# Run AGCD # Run AGCD
recovered_p = run_agcd(tmp_file.name) recovered_p = run_agcd(tmp_file.name)
# Check if recovery was successful # Check if recovery was successful
success = 1 if recovered_p != None and abs(recovered_p - true_p) <= 4 else 0 if recovered_p is not None and abs(recovered_p-true_p) <= 2000:
successes.append(success) successes += 1
# Clean up # Clean up
os.unlink(tmp_file.name) os.unlink(tmp_file.name)
print(f"Number of values: {num_values}, Success: {'Yes' if success else 'No'}") success_rate = successes / num_trials
success_rates.append(success_rate)
print(f"Number of values: {num_values}, Success rate: {success_rate:.3f} ({successes}/{num_trials})")
# Plot the results # Plot the results
plot_curves(noise_bits, p_bits, test_numbers, successes) plot_curves(noise_bits, p_bits, test_numbers, success_rates)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -1,16 +1,21 @@
use crate::bkz::bkz_reduce;
use crate::matrix::Matrix; use crate::matrix::Matrix;
use crate::utils::abs; use crate::utils::abs;
use lll_rs::l2::bigl2; use lll_rs::l2::bigl2;
use rug::Integer; use rug::Integer;
pub fn agcd(ciphertexts: Vec<Integer>, noise_bits: usize) -> Integer { pub fn agcd(ciphertexts: Vec<Integer>, noise_bits: usize, algorithm: u8) -> Integer {
// 1. Build lattice matrix basis // 1. Build lattice matrix basis
let basis_matrix = Matrix::new_lattice(noise_bits, ciphertexts.clone()).unwrap(); let basis_matrix = Matrix::new_lattice(noise_bits, ciphertexts.clone()).unwrap();
// 2. reduce with LLL // 2. reduce with LLL
let mut lll_matrix = basis_matrix.to_lll_matrix(); let mut lll_matrix = basis_matrix.to_lll_matrix();
println!("basis: {:?}", lll_matrix); println!("basis: {:?}", lll_matrix);
bigl2::lattice_reduce(&mut lll_matrix, 0.51, 0.75); match algorithm {
0u8 => bigl2::lattice_reduce(&mut lll_matrix, 0.51, 0.75),
1u8 => bkz_reduce(&mut lll_matrix, 16, 0.75, 0.75, 10),
_ => panic!(),
}
println!("basis after reduction: {:?}", lll_matrix); println!("basis after reduction: {:?}", lll_matrix);
// 3. Extract shortest vector // 3. Extract shortest vector
@ -36,7 +41,10 @@ pub fn agcd(ciphertexts: Vec<Integer>, noise_bits: usize) -> Integer {
let p_guess = abs((x0 - r0) / q0); let p_guess = abs((x0 - r0) / q0);
println!("Recovered p: {}", p_guess); println!("Recovered p: {}", p_guess);
println!("Approximate GCD with noise_bits={}: {}", noise_bits, p_guess); println!(
"Approximate GCD with noise_bits={}: {}",
noise_bits, p_guess
);
p_guess p_guess
} }

122
src/bkz.rs Normal file
View File

@ -0,0 +1,122 @@
use lll_rs::vector::Vector;
use lll_rs::{l2::bigl2, matrix::Matrix as LLLMatrix, vector::BigVector};
use rug::Integer;
use std::cmp::min;
/// Gram-Schmidt orthogonalization
fn compute_gram_schmidt(
basis: &LLLMatrix<BigVector>,
) -> (Vec<BigVector>, Vec<Vec<Integer>>, Vec<Integer>) {
let n = basis.dimensions().1;
let mut orthogonal = Vec::with_capacity(n);
let mut mu = vec![vec![Integer::new(); n]; n];
let mut b = Vec::with_capacity(n);
for i in 0..n {
let mut gs = basis[i].clone();
for j in 0..i {
let numerator = inner_product(&gs, &orthogonal[j]);
let denominator: &Integer = &b[j];
mu[i][j] = numerator.clone() / denominator.clone();
gs = gs.sub(&orthogonal[j].clone().mulf(&mu[i][j].clone()));
}
let norm = inner_product(&gs, &gs);
b.push(norm);
orthogonal.push(gs);
}
(orthogonal, mu, b)
}
/// BigVectors product
fn inner_product(a: &BigVector, b: &BigVector) -> Integer {
assert!(a.dimension() == b.dimension());
let mut sum = Integer::new();
for i in 0..a.dimension() {
sum += &a[i] * &b[i];
}
sum
}
/// vectors orthogonal projection
fn project_block(
basis: &LLLMatrix<BigVector>,
start: usize,
end: usize,
orthogonal: &[BigVector],
mu: &[Vec<Integer>],
) -> LLLMatrix<BigVector> {
let mut projected = LLLMatrix::init(end - start, basis.dimensions().0);
for (idx, i) in (start..end).enumerate() {
let mut v = basis[i].clone();
for (j, ortho_vec) in orthogonal.iter().enumerate().take(start) {
let coeff = &mu[i][j];
v = v.sub(&ortho_vec.clone().mulf(&coeff.clone()));
}
projected[idx] = v;
}
projected
}
/// BKZ reduction using LLL as SVPoracle
pub fn bkz_reduce(
basis: &mut LLLMatrix<BigVector>,
block_size: usize,
delta: f64,
eta: f64,
max_rounds: usize,
) {
let n = basis.dimensions().1;
for _ in 0..max_rounds {
let (orthogonal, mu, _b) = compute_gram_schmidt(basis);
for k in 0..n {
let l = min(k + block_size, n);
if l - k < 2 {
continue;
}
let mut projected = project_block(basis, k, l, &orthogonal, &mu);
bigl2::lattice_reduce(&mut projected, delta, eta);
for (i_block, i) in (k..l).enumerate() {
basis[i] = projected[i_block].clone();
}
}
}
bigl2::lattice_reduce(basis, delta, eta);
}
#[cfg(test)]
mod tests {
use super::*;
use crate::matrix::Matrix;
use rug::Integer;
#[test]
fn test_bkz_reduction() {
let ciphertexts = vec![Integer::from(5), Integer::from(8), Integer::from(12)];
let noise_bits = 2;
let basis_matrix = Matrix::new_lattice(noise_bits, ciphertexts).unwrap();
let mut lll_matrix = basis_matrix.to_lll_matrix();
bkz_reduce(&mut lll_matrix, 5, 0.51, 0.75, 3);
let shortest_vector = &lll_matrix[0];
let mut norm = Integer::from(0);
for i in 0..=lll_matrix.dimensions().0 - 1 {
norm += &shortest_vector[i] * &shortest_vector[i];
}
assert!(
norm > Integer::from(0),
"Shortest vector should not be zero"
);
}
}

View File

@ -1,4 +1,5 @@
mod agcd; mod agcd;
mod bkz;
mod file; mod file;
mod lll; mod lll;
mod matrix; mod matrix;
@ -23,6 +24,9 @@ enum Commands {
Agcd { Agcd {
/// (default './input.txt') /// (default './input.txt')
path: Option<PathBuf>, path: Option<PathBuf>,
/// Algorithm variant to use (default: 0)
#[arg(short = 'a', long, default_value_t = 0u8)]
algorithm: u8,
}, },
} }
@ -30,11 +34,11 @@ fn main() -> std::io::Result<()> {
let cli = Cli::parse(); let cli = Cli::parse();
match &cli.command { match &cli.command {
Commands::Agcd { path } => { Commands::Agcd { path, algorithm } => {
let path = path.as_deref().unwrap_or(Path::new("./input.txt")); let path = path.as_deref().unwrap_or(Path::new("./input.txt"));
let input = parse_file(path)?; let input = parse_file(path)?;
agcd(input.numbers, input.noise_bits); agcd(input.numbers, input.noise_bits, *algorithm);
} }
} }
Ok(()) Ok(())