module Data.Vector.Bit (
BitVector,
unpack, pack,
unpackInteger, packInteger, unpackInt, packInt,
pad, padMax, zipPad, trimLeading
)
where
import Data.Bits
import qualified Data.Vector.Unboxed as V
-- | A bit vector stored least-significant bit first: index @i@ holds the
-- bit of weight 2^i ('unpack' fills index @i@ from @'testBit' w i@, and
-- 'shiftL' prepends low-order 'False' bits). Length is not fixed; values
-- that differ only by trailing (high-order) 'False' bits denote the same
-- number.
type BitVector = V.Vector Bool
-- | Extend a vector with high-order 'False' bits until it is @i@ bits long.
--
-- A vector that is already at least @i@ bits long is returned unchanged
-- (never truncated): 'V.replicate' yields an empty vector for a
-- non-positive count.
pad :: Int -> BitVector -> BitVector
pad i v = v V.++ V.replicate (i - V.length v) False
-- | Pad both vectors with high-order zeros to the length of the longer one,
-- so the results can be combined position by position.
padMax :: BitVector -> BitVector -> (BitVector, BitVector)
padMax xs ys =
  let width = V.length xs `max` V.length ys
   in (pad width xs, pad width ys)
-- | Pair up corresponding bits of two vectors. The shorter vector is first
-- padded with high-order zeros, so no bit of the longer one is dropped.
zipPad :: BitVector -> BitVector -> V.Vector (Bool, Bool)
zipPad xs ys = V.zip xs' ys'
  where
    (xs', ys') = padMax xs ys
-- | Drop high-order zero bits, i.e. the run of 'False' at the END of the
-- vector (index 0 is the least significant bit). An all-'False' vector
-- becomes empty.
trimLeading :: BitVector -> BitVector
trimLeading v = V.take keep v
  where
    -- number of trailing False bits, counted from the high end
    highZeros = V.length (V.takeWhile not (V.reverse v))
    keep = V.length v - highZeros
-- | Unsigned arithmetic on little-endian bit vectors (index 0 is the least
-- significant bit) of arbitrary length.
--
-- NOTE(review): an instance on a type synonym needs TypeSynonymInstances /
-- FlexibleInstances, which are not visible here — confirm they are enabled
-- in the build configuration.
instance Num BitVector where
  fromInteger = unpackInteger

  -- Ripple-carry addition over the zero-padded operands. The scan is seeded
  -- with (sum, carry) = (False, False), so it always has at least one
  -- element: 'V.last' is safe, and 'V.tail' drops only the seed's sum bit.
  -- A final carry becomes one extra high-order bit.
  as + bs = if cout then V.tail sums `V.snoc` True else V.tail sums
    where
      cout = V.last carries
      (sums, carries) = V.unzip sumsAndCarries
      sumsAndCarries = V.scanl' fullAdd (False, False) (zipPad as bs)
      -- the previous sum bit is irrelevant; only its carry feeds forward
      fullAdd (_, cin) (a, b) =
        ((a /= b) /= cin, (a && b) || (cin && (a /= b)))

  -- Shift-and-add: every set bit i of @as@ contributes @bs@ shifted left by
  -- i positions; clear bits contribute the empty vector (zero).
  as * bs = trimLeading (sum partials)
    where
      partials = zipWith shiftMult (V.toList as) [0 ..]
      shiftMult True i = bs `shiftL` i
      shiftMult False _ = V.empty

  -- Two's-complement subtraction at the width of the longer operand:
  -- as + complement bs + 1, truncated to that width. When @bs@ exceeds @as@
  -- the result wraps modulo 2^width instead of becoming negative.
  as - bs = trimLeading $ V.take (V.length as') (rawSum + 1)
    where
      rawSum = as' + complement bs'
      (as', bs') = padMax as bs

  abs = id

  -- Zero can be written with any number of high-order False bits, not only
  -- as the empty vector, so look for a set bit instead of testing emptiness.
  signum v
    | V.or v = 1
    | otherwise = 0
-- | Bitwise operations on little-endian bit vectors (index 0 is the least
-- significant bit). Operands of different lengths behave as if the shorter
-- one were padded with high-order zeros.
--
-- NOTE(review): 'popCount' and 'bitSizeMaybe' are not defined here; newer
-- base lists them in the class's MINIMAL set, while very old base may lack
-- them — confirm the targeted base version before adding them.
instance Bits BitVector where
  -- AND with an implicit zero bit is zero, so truncating to the shorter
  -- operand loses no set bits.
  (.&.) = V.zipWith (&&)

  -- OR / XOR must keep the longer operand's high bits: pad first.
  xs .|. ys = uncurry (V.zipWith (||)) (padMax xs ys)
  xor xs ys = uncurry (V.zipWith (/=)) (padMax xs ys)

  -- Fixed-width complement: only the bits present in the vector are flipped.
  complement = V.map not

  -- Shifting left grows the vector (inserts low-order zeros); shifting right
  -- discards the i lowest bits.
  shiftL v i = V.replicate i False V.++ v
  shiftR = flip V.drop

  -- Rotations are taken modulo the vector length, so rotating by the length
  -- (or a multiple of it) is the identity. 'max 1' keeps the empty vector
  -- from causing a division by zero.
  rotateR v i = high V.++ low
    where
      (low, high) = V.splitAt (i `mod` max 1 (V.length v)) v
  rotateL v i = rotateR v (negate i)

  bitSize = V.length
  isSigned = const False

  -- Bits beyond the end of the vector are implicit zeros.
  testBit v i = i >= 0 && i < V.length v && v V.! i
  bit i = V.replicate i False `V.snoc` True

  -- Clear in place; a bit past the end is already zero. (The default,
  -- x .&. complement (bit i), would drop every bit above i.)
  clearBit v i
    | i >= 0 && i < V.length v = v V.// [(i, False)]
    | otherwise = v
-- | Convert a fixed-width 'Bits' value to a bit vector, least significant
-- bit first, with high-order zero bits trimmed.
--
-- The width comes from 'bitSize', which is undefined for types without a
-- fixed size such as 'Integer'; use 'unpackInteger' for those.
unpack :: (Bits a) => a -> BitVector
unpack w = trimLeading bits
  where
    width = bitSize w
    bits = V.generate width (testBit w)
-- | Convert a bit vector (least significant bit first) back into a 'Bits'
-- value: start from 0 and set bit @i@ for every 'True' at index @i@.
--
-- The 'Num' constraint supplies the literal starting value; since base 4.6
-- 'Bits' no longer implies 'Num', so without it this does not compile.
-- NOTE(review): for fixed-width targets, indices beyond the type's width
-- are passed to 'setBit' unchanged and presumably have no effect — confirm
-- for the types in use.
pack :: (Num a, Bits a) => BitVector -> a
pack v = V.ifoldl' set 0 v
  where
    set w i b
      | b = w `setBit` i
      | otherwise = w
-- | Convert a non-negative 'Integer' to a bit vector, least significant bit
-- first, with no high-order zero bits (0 becomes the empty vector).
--
-- Negative inputs are rejected with an error. Repeated division by 2 never
-- reaches 0 for them (@divMod (-1) 2 == (-1, 1)@), so without the check the
-- unfold would run forever.
unpackInteger :: Integer -> BitVector
unpackInteger n
  | n < 0 = error "Data.Vector.Bit.unpackInteger: negative argument"
  | otherwise = V.unfoldr step n
  where
    step 0 = Nothing
    step k = Just (odd k, k `div` 2)
-- | 'pack' specialised to 'Integer'. Because 'Integer' is unbounded, every
-- set bit of the vector is represented in the result.
packInteger :: BitVector -> Integer
packInteger v = pack v
-- | 'unpack' specialised to 'Int'. The width is that of 'Int', so a negative
-- input (whose two's-complement high bits are set) comes back at full width
-- with nothing trimmed.
unpackInt :: Int -> BitVector
unpackInt n = unpack n
-- | 'pack' specialised to 'Int'.
-- NOTE(review): bits at or above the width of 'Int' are presumably
-- discarded by 'setBit'; confirm if vectors may be that long.
packInt :: BitVector -> Int
packInt v = pack v