diff --git a/src/tea3/variable_search.py b/src/tea3/variable_search.py index 9a5d157..ecc2216 100644 --- a/src/tea3/variable_search.py +++ b/src/tea3/variable_search.py @@ -2,49 +2,75 @@ from itertools import product as iproduct from tea3.tea3model import Tea3Model from tea3.pretty_print import pretty_print + +def count_monomials(poly): + s = pretty_print(poly).strip() + if not s or s == "0": + return 0 + return sum(1 for term in s.split("+") if term.strip()) + + def apply_variable_change(model, coeffs): - orig_x = [list(model.v[i*8 : (i+1)*8]) for i in range(5)] + orig_x = [list(model.v[i * 8 : (i + 1) * 8]) for i in range(5)] orig_y = model.y_bits - subs = {} - for i in range(1, 5): for j in range(8): subs[orig_x[i][j]] = orig_y[i][j] - for j in range(8): new_expr = orig_y[0][j] for i, ai in enumerate(coeffs, start=1): if ai: new_expr = new_expr + orig_y[i][j] subs[orig_x[0][j]] = new_expr - new_R = [] for i in range(8): row = [poly.subs(subs) for poly in model.R_bits[i]] new_R.append(row) return new_R + def run_exhaustive(steps: int, target_reg: int = 0, target_bit: int = -1): model = Tea3Model() for _ in range(steps): model.step() - snapshot = model + if target_bit == -1: + bits = list(range(8)) + else: + if not (0 <= target_bit < 8): + raise ValueError("target_bit must be in [0, 7] or -1") + bits = [target_bit] + + best_coeffs: tuple | None = None + best_label: str = "" + best_total: int = -1 + best_R: list | None = None + for idx, coeffs in enumerate(iproduct([0, 1], repeat=4)): new_R = apply_variable_change(snapshot, coeffs) - label = "".join(map(str, coeffs)) print(f"\n[{idx:02d}] (a1,a2,a3,a4) = {label}") - if target_bit == -1: - bits = range(8) - else: - if not (0 <= target_bit < 8): - raise ValueError("target_bit must be in [0, 7] or -1") - bits = [target_bit] - + total = 0 for j in bits: poly = new_R[target_reg][j] - print(f" R[{target_reg}][{j}] = {pretty_print(poly)}") + pp = pretty_print(poly) + print(f" R[{target_reg}][{j}] = {pp}") + total += count_monomials(poly) + + if best_total == -1 or total < best_total: + best_total = total + best_coeffs = coeffs + best_label = label + best_R = new_R + + + print("\nBest variable change (fewest total monomials)") + print(f" (a1,a2,a3,a4) = ({', '.join(map(str, best_coeffs))})") + print(f" Total monomials: {best_total}") + print(" Polynomial:\n") + for j in bits: + poly = best_R[target_reg][j] + print(f"R[{target_reg}][{j}] = {pretty_print(poly)}")