walsh transform

This commit is contained in:
2026-05-18 15:05:09 +02:00
parent 3f93c074ef
commit 84fda8757c
5 changed files with 108 additions and 0 deletions
+10
View File
@@ -40,3 +40,13 @@ output:
``` ```
*term order may differ a little* *term order may differ a little*
### walsh.py
Walsh transform of a boolean expression.
#### run:
```
walsh
```
+1
View File
@@ -7,6 +7,7 @@ requires-python = ">=3.10"
[project.scripts] [project.scripts]
anf = "anf.cli:main" anf = "anf.cli:main"
walsh = "walsh.cli:main"
[tool.pytest.ini_options] [tool.pytest.ini_options]
testpaths = ["tests"] testpaths = ["tests"]
+18
View File
@@ -0,0 +1,18 @@
from __future__ import annotations
import sys
from .walsh import walsh_transform
def main() -> None:
if len(sys.argv) > 1:
expr = " ".join(sys.argv[1:])
else:
expr = input("Enter formula: ").strip()
print(walsh_transform(expr))
if __name__ == "__main__":
main()
+68
View File
@@ -0,0 +1,68 @@
from __future__ import annotations
import re
from typing import Sequence
from anf.core import Poly
from anf.parser import Parser, tokenize
_VAR_RE = re.compile(r"[Xx](\d+)$")
def parse_boolean_expression(expr: str) -> Poly:
return Parser(tokenize(expr)).parse()
def _var_sort_key(name: str) -> tuple[int, int | str]:
m = _VAR_RE.fullmatch(name)
if m:
return (0, int(m.group(1)))
return (1, name)
def infer_variables(poly: Poly) -> list[str]:
names: set[str] = set()
for monomial in poly:
names.update(monomial)
return sorted(names, key=_var_sort_key)
def eval_poly(poly: Poly, assignment: dict[str, int]) -> int:
out = 0
for monomial in poly:
term = 1
for name in monomial:
term &= assignment[name]
if term == 0:
break
out ^= term
return out
def walsh_transform(
expr: str | Poly,
variables: Sequence[str] | None = None,
) -> list[int]:
"""
Compute the Walsh transform of a Boolean function.
Returns a list W where W[u] is the Walsh coefficient for mask u,
with u interpreted in the same variable order as variables.
"""
poly = parse_boolean_expression(expr) if isinstance(expr, str) else expr
vars_ = list(variables) if variables is not None else infer_variables(poly)
n = len(vars_)
coeffs: list[int] = []
for u in range(1 << n):
total = 0
for x in range(1 << n):
assignment = {vars_[i]: (x >> i) & 1 for i in range(n)}
fx = eval_poly(poly, assignment)
ux = ((u & x).bit_count() & 1)
total += 1 if (fx ^ ux) == 0 else -1
coeffs.append(total)
return coeffs
+11
View File
@@ -0,0 +1,11 @@
from walsh.walsh import walsh_transform
def test_walsh_xor_two_variables():
assert walsh_transform("x1 XOR x2") == [0, 0, 0, 4]
def test_walsh_or_two_variables():
assert walsh_transform("x1 OR x2") == [-2, 2, 2, 2]
def test_walsh_xor_three_variables():
assert walsh_transform("x1 XOR x2 XOR x3") == [0, 0, 0, 0, 0, 0, 0, 8]