walsh transform
This commit is contained in:
@@ -40,3 +40,13 @@ output:
|
|||||||
```
|
```
|
||||||
|
|
||||||
*term order may differ a little*
|
*term order may differ a little*
|
||||||
|
|
||||||
|
|
||||||
|
### walsh.py
|
||||||
|
|
||||||
|
Walsh transform of a boolean expression.
|
||||||
|
|
||||||
|
#### run:
|
||||||
|
```
|
||||||
|
walsh
|
||||||
|
```
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ requires-python = ">=3.10"
|
|||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
anf = "anf.cli:main"
|
anf = "anf.cli:main"
|
||||||
|
walsh = "walsh.cli:main"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from .walsh import walsh_transform
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
expr = " ".join(sys.argv[1:])
|
||||||
|
else:
|
||||||
|
expr = input("Enter formula: ").strip()
|
||||||
|
|
||||||
|
print(walsh_transform(expr))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
from anf.core import Poly
|
||||||
|
from anf.parser import Parser, tokenize
|
||||||
|
|
||||||
|
|
||||||
|
_VAR_RE = re.compile(r"[Xx](\d+)$")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_boolean_expression(expr: str) -> Poly:
|
||||||
|
return Parser(tokenize(expr)).parse()
|
||||||
|
|
||||||
|
|
||||||
|
def _var_sort_key(name: str) -> tuple[int, int | str]:
|
||||||
|
m = _VAR_RE.fullmatch(name)
|
||||||
|
if m:
|
||||||
|
return (0, int(m.group(1)))
|
||||||
|
return (1, name)
|
||||||
|
|
||||||
|
|
||||||
|
def infer_variables(poly: Poly) -> list[str]:
|
||||||
|
names: set[str] = set()
|
||||||
|
for monomial in poly:
|
||||||
|
names.update(monomial)
|
||||||
|
return sorted(names, key=_var_sort_key)
|
||||||
|
|
||||||
|
|
||||||
|
def eval_poly(poly: Poly, assignment: dict[str, int]) -> int:
|
||||||
|
out = 0
|
||||||
|
for monomial in poly:
|
||||||
|
term = 1
|
||||||
|
for name in monomial:
|
||||||
|
term &= assignment[name]
|
||||||
|
if term == 0:
|
||||||
|
break
|
||||||
|
out ^= term
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def walsh_transform(
|
||||||
|
expr: str | Poly,
|
||||||
|
variables: Sequence[str] | None = None,
|
||||||
|
) -> list[int]:
|
||||||
|
"""
|
||||||
|
Compute the Walsh transform of a Boolean function.
|
||||||
|
Returns a list W where W[u] is the Walsh coefficient for mask u,
|
||||||
|
with u interpreted in the same variable order as variables.
|
||||||
|
"""
|
||||||
|
poly = parse_boolean_expression(expr) if isinstance(expr, str) else expr
|
||||||
|
|
||||||
|
vars_ = list(variables) if variables is not None else infer_variables(poly)
|
||||||
|
n = len(vars_)
|
||||||
|
|
||||||
|
coeffs: list[int] = []
|
||||||
|
|
||||||
|
for u in range(1 << n):
|
||||||
|
total = 0
|
||||||
|
for x in range(1 << n):
|
||||||
|
assignment = {vars_[i]: (x >> i) & 1 for i in range(n)}
|
||||||
|
fx = eval_poly(poly, assignment)
|
||||||
|
ux = ((u & x).bit_count() & 1)
|
||||||
|
total += 1 if (fx ^ ux) == 0 else -1
|
||||||
|
coeffs.append(total)
|
||||||
|
|
||||||
|
return coeffs
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
from walsh.walsh import walsh_transform
|
||||||
|
|
||||||
|
def test_walsh_xor_two_variables():
|
||||||
|
assert walsh_transform("x1 XOR x2") == [0, 0, 0, 4]
|
||||||
|
|
||||||
|
def test_walsh_or_two_variables():
|
||||||
|
assert walsh_transform("x1 OR x2") == [-2, 2, 2, 2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_walsh_xor_three_variables():
|
||||||
|
assert walsh_transform("x1 XOR x2 XOR x3") == [0, 0, 0, 0, 0, 0, 0, 8]
|
||||||
Reference in New Issue
Block a user