diff --git a/README.md b/README.md index 9798a8f..61ef784 100644 --- a/README.md +++ b/README.md @@ -40,3 +40,13 @@ output: ``` *term order may differ a little* + + +### walsh.py + +Walsh transform of a boolean expression. + +#### run: +``` +walsh +``` diff --git a/pyproject.toml b/pyproject.toml index 774c759..6412233 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ requires-python = ">=3.10" [project.scripts] anf = "anf.cli:main" +walsh = "walsh.cli:main" [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/src/walsh/cli.py b/src/walsh/cli.py new file mode 100644 index 0000000..06848e0 --- /dev/null +++ b/src/walsh/cli.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import sys + +from .walsh import walsh_transform + + +def main() -> None: + if len(sys.argv) > 1: + expr = " ".join(sys.argv[1:]) + else: + expr = input("Enter formula: ").strip() + + print(walsh_transform(expr)) + + +if __name__ == "__main__": + main() diff --git a/src/walsh/walsh.py b/src/walsh/walsh.py new file mode 100644 index 0000000..4c6fc72 --- /dev/null +++ b/src/walsh/walsh.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import re +from typing import Sequence + +from anf.core import Poly +from anf.parser import Parser, tokenize + + +_VAR_RE = re.compile(r"[Xx](\d+)$") + + +def parse_boolean_expression(expr: str) -> Poly: + return Parser(tokenize(expr)).parse() + + +def _var_sort_key(name: str) -> tuple[int, int | str]: + m = _VAR_RE.fullmatch(name) + if m: + return (0, int(m.group(1))) + return (1, name) + + +def infer_variables(poly: Poly) -> list[str]: + names: set[str] = set() + for monomial in poly: + names.update(monomial) + return sorted(names, key=_var_sort_key) + + +def eval_poly(poly: Poly, assignment: dict[str, int]) -> int: + out = 0 + for monomial in poly: + term = 1 + for name in monomial: + term &= assignment[name] + if term == 0: + break + out ^= term + return out + + +def walsh_transform( + expr: str | Poly, + variables: Sequence[str] | None = None, +) -> list[int]: + """ + Compute the Walsh transform of a Boolean function. + Returns a list W where W[u] is the Walsh coefficient for mask u, + with u interpreted in the same variable order as variables. + """ + poly = parse_boolean_expression(expr) if isinstance(expr, str) else expr + + vars_ = list(variables) if variables is not None else infer_variables(poly) + n = len(vars_) + + coeffs: list[int] = [] + + for u in range(1 << n): + total = 0 + for x in range(1 << n): + assignment = {vars_[i]: (x >> i) & 1 for i in range(n)} + fx = eval_poly(poly, assignment) + ux = ((u & x).bit_count() & 1) + total += 1 if (fx ^ ux) == 0 else -1 + coeffs.append(total) + + return coeffs diff --git a/tests/test_walsh.py b/tests/test_walsh.py new file mode 100644 index 0000000..6ddc3dc --- /dev/null +++ b/tests/test_walsh.py @@ -0,0 +1,11 @@ +from walsh.walsh import walsh_transform + +def test_walsh_xor_two_variables(): + assert walsh_transform("x1 XOR x2") == [0, 0, 0, 4] + +def test_walsh_or_two_variables(): + assert walsh_transform("x1 OR x2") == [-2, 2, 2, 2] + + +def test_walsh_xor_three_variables(): + assert walsh_transform("x1 XOR x2 XOR x3") == [0, 0, 0, 0, 0, 0, 0, 8]