97 lines
3.3 KiB
Haskell
97 lines
3.3 KiB
Haskell
module LinearAlgebra.GF2
|
|
( -- * Types
|
|
BitVec
|
|
, Mask
|
|
-- * Conversions
|
|
, fromBools
|
|
, toBools
|
|
-- * Gaussian elimination
|
|
, gaussianEliminationMask -- :: nCols -> [BitVec] -> Maybe Mask
|
|
, gaussianEliminationIndices -- :: nCols -> [BitVec] -> Maybe [Int]
|
|
, maskToIndicesInt -- :: Mask -> [Int]
|
|
) where
|
|
|
|
import Data.Bits
|
|
import Data.List (findIndex, foldl')
|
|
import Data.Maybe (listToMaybe, fromMaybe)
|
|
|
|
-- | A vector over GF(2), stored as an 'Integer' bitmask: bit @i@ holds the
-- coordinate for column @i@.
type BitVec = Integer

-- | A row-selection mask: bit @i@ set means original row @i@ is included.
type Mask = Integer
|
|
|
|
-- Conversions
|
|
|
|
-- | Pack a list of booleans into a 'BitVec'. The element at list position
-- @i@ becomes bit @i@, so the head of the list is the least significant bit.
--
-- >>> fromBools [True, False, True]
-- 5
fromBools :: [Bool] -> BitVec
fromBools bs = foldl' (.|.) 0 [bit i | (i, True) <- zip [0 ..] bs]
|
|
|
|
-- | Unpack the lowest @nCols@ bits of a 'BitVec' into booleans, least
-- significant bit first. A non-positive @nCols@ yields the empty list.
toBools :: Int -> BitVec -> [Bool]
toBools nCols v = map (testBit v) [0 .. nCols - 1]
|
|
|
|
-- Gaussian elimination over GF(2) using Integers
|
|
-- | Search for a non-empty subset of the input rows whose XOR is the
-- all-zero vector, using Gaussian elimination over GF(2).
--
-- The result is a 'Mask' in which bit @i@ is set iff input row @i@ belongs
-- to the combination. 'Nothing' is returned for an empty input, or when no
-- combination reducing to zero is found (for example, when the rows are
-- linearly independent).
--
-- Only columns @0 .. nCols - 1@ are eliminated. Set bits at or above @nCols@
-- are never cleared, so they can prevent a dependency from being reported.
--
-- >>> gaussianEliminationMask 2 [3, 1, 2]
-- Just 7
-- >>> gaussianEliminationMask 2 [1, 2]
-- Nothing
gaussianEliminationMask :: Int -> [BitVec] -> Maybe Mask
gaussianEliminationMask nCols rows
  | null rows = Nothing
  | otherwise = findZeroMask (eliminate 0 augmented)
  where
    -- Pair each row with a mask of the input rows it was built from.
    -- Initially, row i is just itself (bit i).
    augmented :: [(BitVec, Mask)]
    augmented = zip rows (map bit [0 ..])

    -- Eliminate column by column over the rows that have NOT yet served as
    -- a pivot. A pivot row is retired once used. Reusing it for a later
    -- column would XOR it back into rows whose earlier columns were already
    -- cleared, and dependencies could then be missed.
    --
    -- Invariant: after processing column c, every remaining (active) row has
    -- bits 0 .. c cleared. Active rows keep their original relative order.
    eliminate :: Int -> [(BitVec, Mask)] -> [(BitVec, Mask)]
    eliminate col active
      | col >= nCols = active
      | otherwise =
          case break (hasBit col) active of
            -- No active row has this column set: nothing to eliminate.
            (_, []) -> eliminate (col + 1) active
            -- Rows before the pivot lack the bit, so only later rows change.
            (before, (pvVec, pvMask) : after) ->
              eliminate
                (col + 1)
                (before ++ map (xorIfHasBit pvVec pvMask col) after)

    hasBit :: Int -> (BitVec, Mask) -> Bool
    hasBit col (v, _) = testBit v col

    -- Clear column `col` from a row by adding (XOR-ing) the pivot into it,
    -- and record the pivot's source rows in the row's mask.
    xorIfHasBit :: BitVec -> Mask -> Int -> (BitVec, Mask) -> (BitVec, Mask)
    xorIfHasBit pvVec pvMask col (v, mask)
      | testBit v col = (v `xor` pvVec, mask `xor` pvMask)
      | otherwise = (v, mask)

    -- The first active row that reached zero is a dependency. Its mask says
    -- which input rows XOR to produce it.
    findZeroMask :: [(BitVec, Mask)] -> Maybe Mask
    findZeroMask mat = snd <$> listToMaybe (filter ((== 0) . fst) mat)
|
|
|
|
|
|
-- | Like 'gaussianEliminationMask', but the selected input rows are given
-- as a list of row indices (converted with 'maskToIndicesInt') instead of
-- a bitmask.
gaussianEliminationIndices :: Int -> [BitVec] -> Maybe [Int]
gaussianEliminationIndices nCols rows =
  maskToIndicesInt <$> gaussianEliminationMask nCols rows
|
|
|
|
-- helpers
|
|
|
|
-- | List the positions of the set bits of a 'Mask', in ascending order.
-- The list is produced lazily, lowest bit first.
--
-- >>> maskToIndicesInt 5
-- [0,2]
--
-- A negative 'Integer' has infinitely many set bits in two's complement, so
-- it has no finite answer. It is rejected with an 'error' rather than
-- looping forever while accumulating indices.
maskToIndicesInt :: Mask -> [Int]
maskToIndicesInt mask
  | mask < 0 =
      error "maskToIndicesInt: negative mask has infinitely many set bits"
  | otherwise =
      [ i
      | (i, rest) <- zip [0 ..] (takeWhile (/= 0) (iterate (`shiftR` 1) mask))
      , testBit rest 0
      ]
|
|
|
|
-- | Remove the element at index @i@ from a list.
--
-- Returns the removed element together with the remaining elements in their
-- original order. Calls 'error' when @i@ is negative or not less than the
-- list length. A negative index is checked explicitly because 'splitAt'
-- would otherwise clamp it to 0 and silently remove the head.
removeAt :: Int -> [a] -> (a, [a])
removeAt i xs
  | i < 0 = error "removeAt: index out of bounds"
  | otherwise =
      case splitAt i xs of
        (front, y : ys) -> (y, front ++ ys)
        (_, []) -> error "removeAt: index out of bounds"
|