module Data.Vector.Bit (
BitVector,
unpack, pack,
unpackInteger, packInteger, unpackInt, packInt,
pad, padMax, zipPad, trimLeading, (==~)
)
where
import Data.Bits
import Data.Function
import qualified Data.Vector.Unboxed as V
-- | A bit string stored as an unboxed vector of 'Bool'. Throughout this
-- module index 0 holds the least significant bit (see 'unpack' / 'pack'),
-- so "leading" (high-order) zero bits sit at the END of the vector.
type BitVector = V.Vector Bool
-- | Append 'False' (high-order zero) bits to the high-index end of a vector
-- until it is @i@ bits long. A vector that is already @i@ bits or longer is
-- returned unchanged; the numeric value is never altered.
pad :: Int -> BitVector -> BitVector
pad i v = v V.++ V.replicate (max 0 (i - V.length v)) False
-- | Pad both vectors with high-order 'False' bits to the length of the
-- longer one, so they can be combined position-by-position.
padMax :: BitVector -> BitVector -> (BitVector, BitVector)
padMax xs ys = (pad width xs, pad width ys)
  where
    width = V.length xs `max` V.length ys
-- | Pair up the bits of two vectors position-by-position, treating the
-- shorter one as if it were extended with 'False' bits.
zipPad :: BitVector -> BitVector -> V.Vector (Bool, Bool)
zipPad xs ys = V.zip xs' ys'
  where
    (xs', ys') = padMax xs ys
-- | Combine two vectors bit-by-bit with @f@, treating the shorter one as if
-- it were extended with 'False' bits. The result has the longer length.
zipPadWith :: (Bool -> Bool -> Bool) -> BitVector -> BitVector -> BitVector
zipPadWith f xs ys =
  let (xs', ys') = padMax xs ys
   in V.zipWith f xs' ys'
-- | Remove the high-order (highest-index) 'False' bits, i.e. the leading
-- zeros of the number when it is written most-significant-bit first.
trimLeading :: BitVector -> BitVector
trimLeading v = V.take keep v
  where
    -- Count of 'False' bits at the high-index end of the vector.
    trailingZeros = V.length (V.takeWhile not (V.reverse v))
    keep = V.length v - trailingZeros
infix 4 ==~

-- | Numeric equality: compare two bit vectors while ignoring any high-order
-- 'False' padding, so @[True]@ and @[True, False]@ are considered equal.
(==~) :: BitVector -> BitVector -> Bool
xs ==~ ys = on (==) trimLeading xs ys
-- | Unsigned, little-endian arithmetic: index 0 is the least significant bit.
--
-- Subtraction is modular in the width of the longer operand, so @x - y@ with
-- @y > x@ wraps around instead of going negative. 'negate' is the class
-- default @0 - x@ and is therefore modular in the width of @x@.
instance Num BitVector where
  fromInteger = unpackInteger

  -- Ripple-carry addition over the zero-padded operands. The result grows by
  -- one bit only when the final carry is set.
  xs + ys
    | carryOut = V.tail sums `V.snoc` True
    | otherwise = V.tail sums
    where
      -- scanl' always yields at least the seed element, so 'V.last' is safe.
      carryOut = V.last carries
      (sums, carries) = V.unzip (V.scanl' fullAdd (False, False) (zipPad xs ys))
      -- The first component of the accumulator is the previous sum bit, which
      -- is ignored; only the carry propagates.
      fullAdd (_, cin) (a, b) =
        let half = a /= b
         in (half /= cin, (a && b) || (cin && half))

  -- Shift-and-add: one shifted copy of @ys@ for every set bit of @xs@.
  -- The list 'sum' starts from @fromInteger 0@, i.e. the empty vector.
  xs * ys = trimLeading (sum partials)
    where
      partials = zipWith shiftMult (V.toList xs) [0 ..]
      shiftMult True i = ys `shiftL` i
      shiftMult False _ = V.empty

  -- Two's-complement subtraction: xs + complement ys + 1, truncated to the
  -- common width so that any overflow bit is discarded.
  xs - ys = trimLeading $ V.take (V.length xs') (rawSum + 1)
    where
      rawSum = xs' + complement ys'
      (xs', ys') = padMax xs ys

  abs = id

  -- Zero may carry 'False' padding (e.g. from 'pad' or 'clearBit'), so test
  -- for any set bit rather than for emptiness.
  signum v
    | V.or v = 1
    | otherwise = 0
-- | Bitwise operations on variable-width bit vectors. Binary operations pad
-- the shorter operand with high-order 'False' bits; the width used by the
-- rotations and 'finiteBitSize' is the vector's length.
instance Bits BitVector where
  (.&.) = zipPadWith (&&)
  (.|.) = zipPadWith (||)
  xor = zipPadWith (/=)
  complement = V.map not

  -- Canonical zero, matching @fromInteger 0@ (the empty vector).
  zeroBits = V.empty

  -- Shifting left prepends low-order zeros; shifting right drops low bits.
  shiftL v i = V.replicate i False V.++ v
  shiftR = flip V.drop

  -- Rotation amounts are reduced modulo the width, as for fixed-size types.
  rotateR v i
    | V.null v = v
    | otherwise = high V.++ low
    where
      (low, high) = V.splitAt (i `mod` V.length v) v
  rotateL v i
    | V.null v = v
    | otherwise = high V.++ low
    where
      (low, high) = V.splitAt (V.length v - (i `mod` V.length v)) v

  -- Negative positions are treated as absent: 'testBit' reports them as
  -- 'False' and 'bit' yields the empty vector for them, so setting or
  -- clearing one leaves the vector unchanged (and never reaches unsafeUpd).
  setBit v i
    | i < 0 = v
    | i < V.length v = V.unsafeUpd v [(i, True)]
    | otherwise = V.generate (i + 1) (\j -> j == i || testBit v j)
  clearBit v i
    | i < 0 || i >= V.length v = v
    | otherwise = V.unsafeUpd v [(i, False)]

  testBit v i
    | Just b <- v V.!? i = b
    | otherwise = False
  bit n = V.generate (n + 1) (== n)
  bitSizeMaybe = Just . V.length
  bitSize = finiteBitSize
  isSigned = const False
  popCount = V.foldl' (\x b -> if b then x + 1 else x) 0
-- | The width of a bit vector is its length, including any high-order
-- 'False' padding it carries.
instance FiniteBits BitVector where
  finiteBitSize v = V.length v
-- | Convert a fixed-width value into a bit vector (index @i@ holds bit @i@),
-- with high-order zero bits removed.
unpack :: (FiniteBits a) => a -> BitVector
unpack w = trimLeading allBits
  where
    allBits = V.generate (finiteBitSize w) (testBit w)
-- | Build a value from a bit vector by setting bit @i@ of a zero value for
-- every 'True' at index @i@.
pack :: (Num a, Bits a) => BitVector -> a
pack = V.ifoldl' step 0
  where
    step acc i b
      | b = acc `setBit` i
      | otherwise = acc
-- | Convert a non-negative 'Integer' into a bit vector, least significant bit
-- first, with no high-order zero bits (so @0@ becomes the empty vector).
--
-- Calls 'error' on a negative argument. A negative number has no finite
-- unsigned representation: the previous repeated-'divMod' unfolding never
-- reached zero on such input and so never terminated.
unpackInteger :: Integer -> BitVector
unpackInteger n
  | n < 0 = error "Data.Vector.Bit.unpackInteger: negative argument"
  | otherwise = V.unfoldr step n
  where
    step 0 = Nothing
    step q = let (q', r) = q `quotRem` 2 in Just (r /= 0, q')
-- | Interpret a bit vector (least significant bit first) as an 'Integer'.
packInteger :: BitVector -> Integer
packInteger v = pack v
-- | Convert an 'Int' into a bit vector, least significant bit first, with
-- high-order zero bits removed.
unpackInt :: Int -> BitVector
unpackInt n = unpack n
-- | Interpret a bit vector (least significant bit first) as an 'Int'.
packInt :: BitVector -> Int
packInt v = pack v