matrix implementation change

This commit is contained in:
Sam Hadow 2025-05-23 22:37:15 +02:00
parent 98378115ca
commit c107afc342
3 changed files with 85 additions and 64 deletions

View File

@ -92,7 +92,6 @@ pub fn bkz_reduce(
bigl2::lattice_reduce(basis, delta, eta); bigl2::lattice_reduce(basis, delta, eta);
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -1,7 +1,7 @@
use rug::ops::Pow; use rug::ops::Pow;
use rug::Integer;
use std::ops::{Index, IndexMut}; use std::ops::{Index, IndexMut};
use crate::vector::IntVector;
type Element = rug::Integer;
macro_rules! int { macro_rules! int {
($x:expr) => { ($x:expr) => {
@ -13,54 +13,71 @@ macro_rules! int {
pub struct Matrix { pub struct Matrix {
pub n: usize, // rows pub n: usize, // rows
pub m: usize, // columns pub m: usize, // columns
values: Vec<Element>, columns: Vec<IntVector>,
} }
impl Matrix { impl Matrix {
pub fn new(n: usize, m: usize, values: Vec<Element>) -> Option<Self> { pub fn new(n: usize, m: usize, values: Vec<Integer>) -> Option<Self> {
if n * m == values.len() { if n * m != values.len() {
Some(Matrix { n, m, values }) return None;
} else {
None
} }
let mut columns = vec![Vec::with_capacity(m); n];
for (i, value) in values.into_iter().enumerate() {
let col = i % n;
columns[col].push(value);
}
Some(Matrix {
n,
m,
columns: columns.into_iter().map(IntVector::from_vec).collect(),
})
} }
pub fn new_lattice(noise_bits: usize, ciphertexts: Vec<Element>) -> Option<Self> { pub fn new_lattice(noise_bits: usize, ciphertexts: Vec<Integer>) -> Option<Self> {
let n = ciphertexts.len(); let n = ciphertexts.len();
let mut values = Vec::with_capacity(n * n); let mut columns = vec![Vec::with_capacity(n); n];
// First row: [2^(noise_bits+1), ciphertexts[1], ..., ciphertexts[t]] columns[0].push(int!(2u64).pow((noise_bits + 1) as u32));
values.push(int!(2u64).pow(noise_bits as u32 + 1));
values.extend_from_slice(&ciphertexts[1..]);
// x0 on diagonal, 0 everywhere else
let x0 = &ciphertexts[0];
for i in 1..n { for i in 1..n {
let mut row = vec![int!(0); n]; columns[i].push(ciphertexts[i].clone());
row[i] = x0.clone();
values.extend(row);
} }
Matrix::new(n, n, values) for i in 1..n {
for j in 0..n {
if i == j {
columns[j].push(ciphertexts[0].clone());
} else {
columns[j].push(int!(0));
}
}
}
Some(Matrix {
n,
m: n,
columns: columns.into_iter().map(IntVector::from_vec).collect(),
})
} }
} }
impl Index<(usize, usize)> for Matrix { impl Index<(usize, usize)> for Matrix {
type Output = Element; type Output = Integer;
fn index(&self, index: (usize, usize)) -> &Self::Output { fn index(&self, index: (usize, usize)) -> &Self::Output {
if index.0 >= self.m || index.1 >= self.n { let (col, row) = index;
panic!(); if row >= self.n || col >= self.m {
panic!("Matrix index out of bounds");
} }
&self.values[(self.n * index.0) + index.1] &self.columns[row][col]
} }
} }
impl IndexMut<(usize, usize)> for Matrix { impl IndexMut<(usize, usize)> for Matrix {
fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output { fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
if index.0 >= self.m || index.1 >= self.n { let (col, row) = index;
panic!(); if row >= self.n || col >= self.m {
panic!("Matrix index out of bounds");
} }
&mut self.values[(self.n * index.0) + index.1] &mut self.columns[row][col]
} }
} }
@ -68,24 +85,24 @@ impl IndexMut<(usize, usize)> for Matrix {
mod tests { mod tests {
use super::*; use super::*;
use std::panic; use std::panic;
#[test] #[test]
fn simple_matrix() { fn simple_dimensions_and_index() {
assert_eq!( let m = Matrix::new(2, 2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap();
Matrix { assert_eq!(m.n, 2);
n: 2, assert_eq!(m.m, 2);
m: 2,
values: vec![int!(1), int!(2), int!(3), int!(4)], // values: [1,2,3,4]
}, // columns: [[1,3], [2,4]]
Matrix::new(2, 2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap() assert_eq!(m[(0, 0)], int!(1));
); assert_eq!(m[(0, 1)], int!(2));
assert!(Matrix::new(2, 2, vec![int!(1), int!(2)]).is_none()); assert_eq!(m[(1, 0)], int!(3));
assert_eq!(m[(1, 1)], int!(4));
} }
#[test] #[test]
fn indexes() { fn indexes_and_mutation() {
let mut m = Matrix::new(2, 2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap(); let mut m = Matrix::new(2, 2, vec![int!(1), int!(2), int!(3), int!(4)]).unwrap();
assert_eq!(m[(0, 0)], int!(1));
assert_eq!(m[(1, 0)], int!(3)); assert_eq!(m[(1, 0)], int!(3));
m[(1, 0)] = int!(5); m[(1, 0)] = int!(5);
@ -97,38 +114,43 @@ mod tests {
vec![int!(1), int!(2), int!(3), int!(4), int!(5), int!(6)], vec![int!(1), int!(2), int!(3), int!(4), int!(5), int!(6)],
) )
.unwrap(); .unwrap();
assert_eq!(m2[(0, 2)], int!(3)); assert_eq!(m2[(0, 2)], int!(3));
assert_eq!(m2[(1, 0)], int!(4)); assert_eq!(m2[(1, 0)], int!(4));
let result = panic::catch_unwind(|| {
let _ = m2[(0, 3)];
});
assert!(result.is_err(), "Expected panic on m2[(0, 3)]");
let result2 = panic::catch_unwind(|| { let result = panic::catch_unwind(|| { let _ = m2[(2, 0)]; });
let _ = m2[(2, 0)]; assert!(result.is_err(), "Expected panic on m2[(2, 0)]");
});
assert!(result2.is_err(), "Expected panic on m2[(2, 0)]"); let result2 = panic::catch_unwind(|| { let _ = m2[(0, 3)]; });
assert!(result2.is_err(), "Expected panic on m2[(0, 3)]");
} }
#[test] #[test]
fn test_new_lattice() { fn test_new_lattice_layout() {
let ciphertexts = vec![int!(5), int!(8), int!(12)]; let ciphertexts = vec![int!(5), int!(8), int!(12)];
let noise_bits = 2; let noise_bits = 2;
let lattice = Matrix::new_lattice(noise_bits, ciphertexts.clone()).unwrap();
let expected_values = vec![ assert_eq!(lattice.n, 3);
int!(8), assert_eq!(lattice.m, 3);
int!(8),
int!(12), // 1st column = [2^(noise+1), 0, 0] = [8,0,0]
int!(0), // 2nd column = [ciphertexts[1], ciphertexts[0], 0] = [8,5,0]
int!(5), // 3rd column = [ciphertexts[2], 0, ciphertexts[0]] = [12,0,5]
int!(0), let expected_flat = vec![
int!(0), int!(8), int!(8), int!(12),
int!(0), int!(0), int!(5), int!(0),
int!(5), int!(0), int!(0), int!(5),
]; ];
let lattice = Matrix::new_lattice(noise_bits, ciphertexts).unwrap(); let mut actual_flat = Vec::with_capacity(9);
assert_eq!(lattice.n, 3); for col in 0..lattice.m {
assert_eq!(lattice.values, expected_values); for row in 0..lattice.n {
actual_flat.push(lattice[(col, row)].clone());
}
}
assert_eq!(actual_flat, expected_flat);
} }
} }

View File

@ -10,7 +10,7 @@ macro_rules! int {
}; };
} }
#[derive(Clone)] #[derive(Clone, PartialEq)]
pub struct IntVector { pub struct IntVector {
elements: Vec<Integer>, elements: Vec<Integer>,
} }