outputs best variable change

This commit is contained in:
2026-06-08 09:56:32 +02:00
parent 6827515739
commit 8ba3ebf1a1
+38 -12
View File
@@ -2,49 +2,75 @@ from itertools import product as iproduct
from tea3.tea3model import Tea3Model from tea3.tea3model import Tea3Model
from tea3.pretty_print import pretty_print from tea3.pretty_print import pretty_print
def count_monomials(poly):
s = pretty_print(poly).strip()
if not s or s == "0":
return 0
return sum(1 for term in s.split("+") if term.strip())
def apply_variable_change(model, coeffs): def apply_variable_change(model, coeffs):
orig_x = [list(model.v[i * 8 : (i + 1) * 8]) for i in range(5)] orig_x = [list(model.v[i * 8 : (i + 1) * 8]) for i in range(5)]
orig_y = model.y_bits orig_y = model.y_bits
subs = {} subs = {}
for i in range(1, 5): for i in range(1, 5):
for j in range(8): for j in range(8):
subs[orig_x[i][j]] = orig_y[i][j] subs[orig_x[i][j]] = orig_y[i][j]
for j in range(8): for j in range(8):
new_expr = orig_y[0][j] new_expr = orig_y[0][j]
for i, ai in enumerate(coeffs, start=1): for i, ai in enumerate(coeffs, start=1):
if ai: if ai:
new_expr = new_expr + orig_y[i][j] new_expr = new_expr + orig_y[i][j]
subs[orig_x[0][j]] = new_expr subs[orig_x[0][j]] = new_expr
new_R = [] new_R = []
for i in range(8): for i in range(8):
row = [poly.subs(subs) for poly in model.R_bits[i]] row = [poly.subs(subs) for poly in model.R_bits[i]]
new_R.append(row) new_R.append(row)
return new_R return new_R
def run_exhaustive(steps: int, target_reg: int = 0, target_bit: int = -1): def run_exhaustive(steps: int, target_reg: int = 0, target_bit: int = -1):
model = Tea3Model() model = Tea3Model()
for _ in range(steps): for _ in range(steps):
model.step() model.step()
snapshot = model snapshot = model
for idx, coeffs in enumerate(iproduct([0, 1], repeat=4)):
new_R = apply_variable_change(snapshot, coeffs)
label = "".join(map(str, coeffs))
print(f"\n[{idx:02d}] (a1,a2,a3,a4) = {label}")
if target_bit == -1: if target_bit == -1:
bits = range(8) bits = list(range(8))
else: else:
if not (0 <= target_bit < 8): if not (0 <= target_bit < 8):
raise ValueError("target_bit must be in [0, 7] or -1") raise ValueError("target_bit must be in [0, 7] or -1")
bits = [target_bit] bits = [target_bit]
best_coeffs: tuple | None = None
best_label: str = ""
best_total: int = -1
best_R: list | None = None
for idx, coeffs in enumerate(iproduct([0, 1], repeat=4)):
new_R = apply_variable_change(snapshot, coeffs)
label = "".join(map(str, coeffs))
print(f"\n[{idx:02d}] (a1,a2,a3,a4) = {label}")
total = 0
for j in bits: for j in bits:
poly = new_R[target_reg][j] poly = new_R[target_reg][j]
pp = pretty_print(poly)
print(f" R[{target_reg}][{j}] = {pp}")
total += count_monomials(poly)
if best_total == -1 or total < best_total:
best_total = total
best_coeffs = coeffs
best_label = label
best_R = new_R
print("\nBest variable change (fewest total monomials)")
print(f" (a1,a2,a3,a4) = ({', '.join(map(str, best_coeffs))})")
print(f" Total monomials: {best_total}")
print(" Polynomial:\n")
for j in bits:
poly = best_R[target_reg][j]
print(f"R[{target_reg}][{j}] = {pretty_print(poly)}") print(f"R[{target_reg}][{j}] = {pretty_print(poly)}")