Files
haskell-math/src/LinearAlgebra/GF2.hs
2026-01-16 19:27:16 +01:00

97 lines
3.3 KiB
Haskell

module LinearAlgebra.GF2
( -- * Types
BitVec
, Mask
-- * Conversions
, fromBools
, toBools
-- * Gaussian elimination
, gaussianEliminationMask -- :: nCols -> [BitVec] -> Maybe Mask
, gaussianEliminationIndices -- :: nCols -> [BitVec] -> Maybe [Int]
, maskToIndicesInt -- :: Mask -> [Int]
) where
import Data.Bits
import Data.List (findIndex, foldl')
import Data.Maybe (listToMaybe, fromMaybe)
-- | A vector over GF(2), represented as an 'Integer' bitmask: bit @i@ holds
-- the entry in column @i@.
type BitVec = Integer
-- | A selection of rows: bit @i@ set to 1 means original row @i@ is included.
type Mask = Integer
-- Conversions
-- | Pack a list of booleans into a 'BitVec'. Element @i@ of the list becomes
-- bit @i@ of the result, so the head of the list is the least-significant bit.
--
-- Example: @fromBools [True, False, True] == 0b101 == 5@
fromBools :: [Bool] -> BitVec
fromBools flags =
  -- Keep only the positions holding True and OR their single-bit values together.
  foldl' (.|.) 0 [bit idx | (idx, True) <- zip [0 :: Int ..] flags]
-- | Unpack the low @nCols@ bits of a 'BitVec' into booleans, least-significant
-- bit first. The result has length @nCols@, or is empty when @nCols <= 0@.
toBools :: Int -> BitVec -> [Bool]
toBools nCols vec = testBit vec <$> enumFromTo 0 (nCols - 1)
-- Gaussian elimination over GF(2) using Integers
-- | Gaussian elimination over GF(2).
--
-- Each input row is an 'Integer' bitmask whose bit @c@ is the entry in
-- column @c@. Only columns @0 .. nCols - 1@ are eliminated.
--
-- The result is a 'Mask' whose set bits select a non-empty set of input rows
-- that XOR to the all-zero vector. It is 'Nothing' if no such row is found,
-- and always 'Nothing' for an empty input.
--
-- When every row fits in @nCols@ bits, 'Just' is returned exactly when the
-- rows are linearly dependent. In particular, this always happens when there
-- are more rows than columns.
gaussianEliminationMask :: Int -> [BitVec] -> Maybe Mask
gaussianEliminationMask nCols rows
  | null rows = Nothing
  | otherwise = findZeroMask (eliminate 0 augmented)
  where
    -- Pair every row with the set of original rows it is the XOR of.
    -- Initially row i is just itself: mask 1 << i.
    augmented :: [(BitVec, Mask)]
    augmented = zip rows (map bit [0 ..])

    -- Process columns left to right over the rows that have NOT been used as
    -- a pivot yet. A row chosen as pivot for column @col@ is set aside for
    -- good. Reusing it for a later column would XOR its lower bits back into
    -- rows that were already cleared there, which can hide a dependency.
    --
    -- Invariant: every pending row has bits 0 .. col-1 cleared. After the
    -- last column the pending rows are exactly the candidate zero rows. A
    -- discarded pivot always has its pivot bit set, so it can never be zero.
    eliminate :: Int -> [(BitVec, Mask)] -> [(BitVec, Mask)]
    eliminate col pending
      | col >= nCols = pending
      | otherwise =
          case findPivot col pending of
            Nothing -> eliminate (col + 1) pending
            Just pivotIdx ->
              let ((pvVec, pvMask), rest) = removeAt pivotIdx pending
               in eliminate (col + 1) (map (xorIfHasBit pvVec pvMask col) rest)

    -- Index of the first row with a 1 in column @col@, if any.
    findPivot :: Int -> [(BitVec, Mask)] -> Maybe Int
    findPivot col mat = findIndex (\(v, _) -> testBit v col) mat

    -- Clear column @col@ of a row by XORing in the pivot. The row's mask is
    -- updated in the same way so it keeps tracking which originals it combines.
    xorIfHasBit :: BitVec -> Mask -> Int -> (BitVec, Mask) -> (BitVec, Mask)
    xorIfHasBit pvVec pvMask col (v, mask)
      | testBit v col = (v `xor` pvVec, mask `xor` pvMask)
      | otherwise = (v, mask)

    -- Mask of the first row that has been reduced to the zero vector.
    findZeroMask :: [(BitVec, Mask)] -> Maybe Mask
    findZeroMask mat = fmap snd (listToMaybe (filter (\(v, _) -> v == 0) mat))
-- | Same search as 'gaussianEliminationMask', but the resulting mask is decoded
-- into the ascending list of selected row indices (via 'maskToIndicesInt').
gaussianEliminationIndices :: Int -> [BitVec] -> Maybe [Int]
gaussianEliminationIndices nCols =
  fmap maskToIndicesInt . gaussianEliminationMask nCols
-- helpers
-- | List the indices of the set bits of a 'Mask', in ascending order.
--
-- Example: @maskToIndicesInt 0b1011 == [0, 1, 3]@
--
-- A negative 'Integer' has infinitely many set bits in two's complement, so it
-- has no finite index list. Such a mask is rejected with an error instead of
-- looping forever (repeated 'shiftR' on a negative value never reaches 0).
maskToIndicesInt :: Mask -> [Int]
maskToIndicesInt mask
  | mask < 0 = error "maskToIndicesInt: negative mask"
  -- 'popCount' is the exact number of set bits, so 'take' stops right after
  -- the highest one. Bits are tested in place rather than re-shifting (and
  -- re-allocating) the whole Integer at every step.
  | otherwise = take (popCount mask) [i | i <- [0 ..], testBit mask i]
-- | Remove the element at 0-based index @i@ from a list.
--
-- Returns @(element, list-without-element)@; the remaining elements keep their
-- original order. Calls 'error' when @i@ is negative or not less than the
-- length of the list.
removeAt :: Int -> [a] -> (a, [a])
removeAt i xs
  -- 'splitAt' treats a negative index like 0, which would silently remove the
  -- head of the list; reject it explicitly instead.
  | i < 0 = error "removeAt: index out of bounds"
  | otherwise =
      case splitAt i xs of
        (front, y : ys) -> (y, front ++ ys)
        (_, []) -> error "removeAt: index out of bounds"