diff --git a/src/walsh/cli.py b/src/walsh/cli.py index 06848e0..0490e0e 100644 --- a/src/walsh/cli.py +++ b/src/walsh/cli.py @@ -2,7 +2,7 @@ from __future__ import annotations import sys -from .walsh import walsh_transform +from .walsh import walsh_transform, nonlinearity def main() -> None: @@ -12,6 +12,7 @@ def main() -> None: expr = input("Enter formula: ").strip() print(walsh_transform(expr)) + print("non-linearity:", nonlinearity(expr)) if __name__ == "__main__": diff --git a/src/walsh/walsh.py b/src/walsh/walsh.py index 4c6fc72..b8783df 100644 --- a/src/walsh/walsh.py +++ b/src/walsh/walsh.py @@ -66,3 +66,19 @@ def walsh_transform( coeffs.append(total) return coeffs + +def nonlinearity( + expr: str | Poly, + variables: Sequence[str] | None = None, +) -> int: + """ + NL(f) = 2^(n-1) - max_u |W_f(u)| / 2 + """ + poly = parse_boolean_expression(expr) if isinstance(expr, str) else expr + + vars_ = list(variables) if variables is not None else infer_variables(poly) + n = len(vars_) + + spectrum = walsh_transform(poly, vars_) + + return (1 << (n - 1)) - max(abs(w) for w in spectrum) // 2 diff --git a/tests/test_nonlinearity.py b/tests/test_nonlinearity.py new file mode 100644 index 0000000..459edcf --- /dev/null +++ b/tests/test_nonlinearity.py @@ -0,0 +1,13 @@ +from walsh.walsh import nonlinearity + + +def test_nonlinearity_linear_function(): + assert nonlinearity("x1 XOR x2") == 0 + + +def test_nonlinearity_or(): + assert nonlinearity("x1 OR x2") == 1 + + +def test_nonlinearity_majority(): + assert nonlinearity("x1*x2 XOR x1*x3 XOR x2*x3") == 2