From 0e1c67b7aa614b4990bc88d6a3402fbea5cc7ba4 Mon Sep 17 00:00:00 2001 From: Sam Hadow Date: Tue, 21 Apr 2026 16:36:13 +0200 Subject: [PATCH] abstract pool --- src/tea3/cli.py | 2 +- src/tea3/tea3model.py | 107 ++++++++++++++++++++++-------------------- 2 files changed, 56 insertions(+), 53 deletions(-) diff --git a/src/tea3/cli.py b/src/tea3/cli.py index 5590f86..a4b965a 100644 --- a/src/tea3/cli.py +++ b/src/tea3/cli.py @@ -29,7 +29,7 @@ def main() -> None: print(f"\nRunning {steps} step(s), watching R[{reg}][{bit}]") print("-" * 50) - model = Tea3Model(max_steps=steps) + model = Tea3Model() for i in range(steps): model.step() diff --git a/src/tea3/tea3model.py b/src/tea3/tea3model.py index 8970f0d..743c391 100644 --- a/src/tea3/tea3model.py +++ b/src/tea3/tea3model.py @@ -1,11 +1,13 @@ from sage.all import GF, BooleanPolynomialRing +from functools import reduce +from operator import mul from tea3.constants import TEA3_SBOX, T_F1, T_F2 from tea3.pretty_print import pretty_print, pretty_print_vec class Tea3Model: - def __init__(self, max_steps=20): + def __init__(self, max_abstract=40960): self.F = GF(2) self.step_count = 0 @@ -13,71 +15,72 @@ class Tea3Model: [f"x{i}{j}" for i in range(5) for j in range(8)] + [f"r{i}{j}" for i in range(5) for j in range(8)] + [f"R{i}{j}" for i in range(8) for j in range(8)] + - [f"f{s}_{i}{j}" for s in range(max_steps) for i in range(8) for j in range(8)] + [f"g{n}" for n in range(max_abstract)] ) name_string = ",".join(names) self.S = BooleanPolynomialRing(len(names), name_string) self.v = self.S.gens() - self.x_bits = [list(self.v[i*8:(i+1)*8]) for i in range(5)] + self.x_bits = [list(self.v[i*8:(i+1)*8]) for i in range(5)] self.r_bits = [list(self.v[40 + i*8 : 40 + (i+1)*8]) for i in range(5)] self.R_bits = [list(self.v[80 + i*8 : 80 + (i+1)*8]) for i in range(8)] - # Abstract variables - base = 80 + 64 - self.fR_bits = [ - [list(self.v[base + s*64 + i*8 : base + s*64 + i*8 + 8]) - for i in range(8)] - for s in range(max_steps) - ] + self._g_pool = list(self.v[80 + 64:]) + self._g_index = 0 + self.abstractions = {} - def _split_poly(self, poly): - """ - Split a polynomial into: - - R_f_part: monomials involving only 'R' or 'f' (abstract) variables - - xr_part: monomials involving 'x' or 'r' variables - - constant term is grouped with R_f_part when R_f_part is non-zero - """ - zero = self.S.zero() - R_f_part = zero - xr_part = zero - - has_const = bool(poly.constant_coefficient()) - - for monom in poly: - vars_in_term = monom.variables() - if not vars_in_term: - continue - families = {str(v)[0] for v in vars_in_term} - monom_poly = self.S(monom) - if families <= {'R', 'f'}: - R_f_part += monom_poly - else: - xr_part += monom_poly - - if has_const: - if R_f_part != zero: - R_f_part += self.S.one() - else: - xr_part += self.S.one() - - return R_f_part, xr_part + def _alloc(self, step, i, j, xr_key): + if self._g_index >= len(self._g_pool): + raise RuntimeError("Abstract variable pool exhausted — increase max_abstract.") + var = self._g_pool[self._g_index] + self.abstractions[(step, i, j, xr_key)] = (var, self._g_index) + self._g_index += 1 + return var def _abstract_R(self): - """ - Replace the R/f-dependent part of every R_bits[i][j] with an - abstract variable f{step}_{i}{j}, leaving only x and r terms explicit. - """ - s = self.step_count + s = self.step_count + one = self.S.one() + zero = self.S.zero() + for i in range(8): for j in range(8): - R_f_part, xr_part = self._split_poly(self.R_bits[i][j]) - if R_f_part != self.S.zero(): - self.R_bits[i][j] = self.fR_bits[s][i][j] + xr_part - else: - self.R_bits[i][j] = xr_part + poly = self.R_bits[i][j] + + groups = {} + pure_xr = zero + const = one if bool(poly.constant_coefficient()) else zero + + for monom in poly: + term_vars = monom.variables() + if not term_vars: + continue + + xr_vars = [v for v in term_vars if str(v)[0] in ('x', 'r')] + Rf_vars = [v for v in term_vars if str(v)[0] in ('R', 'f', 'g')] # 'g' added + + xr_mono = reduce(mul, (self.S(v) for v in xr_vars), one) + Rf_mono = reduce(mul, (self.S(v) for v in Rf_vars), one) + xr_key = frozenset(str(v) for v in xr_vars) + + if not Rf_vars: + pure_xr += xr_mono + else: + if xr_key not in groups: + groups[xr_key] = {'xr_mono': xr_mono, 'Rf_sum': zero} + groups[xr_key]['Rf_sum'] += Rf_mono + + result = pure_xr + const + + for xr_key, bucket in groups.items(): + Rf_sum = bucket['Rf_sum'] + xr_mono = bucket['xr_mono'] + if Rf_sum == zero: + continue + g = self._alloc(s, i, j, xr_key) + result += xr_mono * g + + self.R_bits[i][j] = result def step(self): R = self.R_bits.copy()