module Data.BitStream
( BitStream
, tabulate
, tabulateFix
, tabulateM
, tabulateFixM
, index
, mapWithKey
, traverseWithKey
, not
, zipWithKey
, zipWithKeyM
, and
, or
) where
import Prelude hiding ((^), (*), div, mod, fromIntegral, not, and, or)
import qualified Prelude as P
import Data.Bits
import Data.Foldable hiding (and, or)
import Data.Function (fix)
import Data.Functor.Identity
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector as V
import Unsafe.Coerce
-- | One bit for every 'Word' key, stored as a vector of segments.
--
-- Segment 0 is a single word holding the bits for keys @[0, bits)@.
-- Segment @n + 1@ holds @2^n@ words, which hold the bits for keys
-- @[2^n * bits, 2^(n+1) * bits)@; here @bits@ is the width of 'Word'.
-- Inside a word whose first key is @base@, bit @k@ stores key @base + k@.
newtype BitStream = BitStream
  { _unBitStream :: V.Vector (U.Vector Word)
  }
-- | Reinterpret a 'Word' as an 'Int' with two's-complement wrap-around
-- (no range check); used to turn keys into vector offsets.
--
-- This goes through 'P.fromIntegral' rather than 'unsafeCoerce'. Coercing
-- between two distinct data types (the @W#@ and @I#@ boxes) is not among
-- the uses "Unsafe.Coerce" documents as supported. GHC optimises this
-- 'fromIntegral' to a primitive conversion, so nothing is lost.
word2int :: Word -> Int
word2int = P.fromIntegral

-- | Reinterpret an 'Int' as a 'Word' with two's-complement wrap-around
-- (negative values map to the top half of the 'Word' range).
int2word :: Int -> Word
int2word = P.fromIntegral
-- | Width of a machine 'Word' in bits: the number of keys packed into
-- each 'Word' stored in a 'BitStream'.
bits :: Int
bits = finiteBitSize (zeroBits :: Word)
-- | Base-2 logarithm of 'bits'.
--
-- Shifting a word index left by this amount gives the key stored in that
-- word's bit 0; shifting a key right by it gives the key's word index.
-- It is computed as the position of the highest set bit of 'bits'.
bitsLog :: Int
bitsLog = bits - 1 - countLeadingZeros (int2word bits)
-- | Build a 'BitStream' from a pure predicate on keys.
--
-- This is 'tabulateM' specialised to the 'Identity' monad.
tabulate :: (Word -> Bool) -> BitStream
tabulate predicate = runIdentity (tabulateM (Identity . predicate))
-- | Build a 'BitStream' by running a monadic predicate once for every key.
--
-- Word 0 (keys @[0, bits)@) is built first. It is followed by
-- @bits - bitsLog@ segments, where segment @i@ holds the @2^i@ words
-- starting at word index @2^i@. Together they reach key @2^bits@, i.e.
-- every 'Word'.
tabulateM :: forall m. Monad m => (Word -> m Bool) -> m BitStream
tabulateM f = do
  z <- tabulateW 0
  zs <- V.generateM (bits - bitsLog) tabulateU
  return $ BitStream $ U.singleton z `V.cons` zs
  where
    -- Segment i: the 2^i consecutive words starting at word index 2^i.
    tabulateU :: Int -> m (U.Vector Word)
    tabulateU i = U.generateM ii (\j -> tabulateW (ii + j))
      where
        ii = 1 `shiftL` i

    -- Word with index j: bit k holds the predicate's answer for key
    -- j * bits + k.
    tabulateW :: Int -> m Word
    tabulateW j = foldlM go zeroBits [0 .. bits - 1]
      where
        jj = j `shiftL` bitsLog
        go acc k = do
          b <- f (int2word $ jj + k)
          return $ if b then acc `setBit` k else acc
-- | Build a 'BitStream' from an open-recursive predicate.
--
-- @uf rec k@ decides key @k@ and may call @rec@ on other keys. Outside the
-- first word, calls on keys from earlier segments are looked up in the
-- stream being built instead of being recomputed (see 'tabulateFixM').
tabulateFix :: ((Word -> Bool) -> Word -> Bool) -> BitStream
tabulateFix uf = runIdentity (tabulateFixM liftedUf)
  where
    liftedUf rec = Identity . uf (runIdentity . rec)
-- | Monadic version of 'tabulateFix'.
--
-- The first word (keys @[0, bits)@) is filled using plain @'fix' uf@.
-- Each later segment passes @uf@ a callback with two cases:
--
-- * keys below the segment's first key are answered with 'index' on the
--   resulting stream;
-- * all other keys are answered by recursing through @uf@.
--
-- NOTE(review): the callback re-binds the whole @bs@ action, so this relies
-- on @m@ being lazy enough to tie the knot (as 'Identity' is, via
-- 'tabulateFix'). In a strict monad each lookup would presumably rebuild
-- the stream from scratch; confirm before using with e.g. 'IO'.
tabulateFixM
  :: forall m. Monad m
  => ((Word -> m Bool) -> Word -> m Bool)
  -> m BitStream
tabulateFixM uf = bs
  where
    bs :: m BitStream
    bs = do
      z <- tabulateW (fix uf) 0
      zs <- V.generateM (bits - bitsLog) tabulateU
      return $ BitStream $ U.singleton z `V.cons` zs

    -- Segment i: the 2^i consecutive words starting at word index 2^i.
    tabulateU :: Int -> m (U.Vector Word)
    tabulateU i = U.generateM ii (\j -> tabulateW (uf f) (ii + j))
      where
        ii = 1 `shiftL` i
        -- First key of segment i; every smaller key is in an earlier segment.
        iii = ii `shiftL` bitsLog
        f k = do
          bs' <- bs
          if k < int2word iii then return (index bs' k) else uf f k

    -- Word with index j: bit k holds the answer of f for key j * bits + k.
    tabulateW :: (Word -> m Bool) -> Int -> m Word
    tabulateW f j = foldlM go zeroBits [0 .. bits - 1]
      where
        jj = j `shiftL` bitsLog
        go acc k = do
          b <- f (int2word $ jj + k)
          return $ if b then acc `setBit` k else acc
-- | Look up the bit stored for key @i@.
--
-- Keys below 'bits' are in segment 0. Otherwise, let @p@ be the position
-- of the highest set bit of @i@ (so @2^p <= i < 2^(p+1)@). The key is then
-- in segment @p - bitsLog + 1@, at bit offset @i - 2^p@ inside it.
--
-- Indexing is unchecked. It assumes the full segment layout that this
-- module's constructors build.
index :: BitStream -> Word -> Bool
index (BitStream vus) i =
  if sgm < 0
    then indexU (V.unsafeHead vus) (word2int i)
    else indexU (vus `V.unsafeIndex` (sgm + 1))
                (word2int $ i - (int2word bits `shiftL` sgm))
  where
    -- (position of highest set bit) - bitsLog. This is negative exactly
    -- when i < bits, including i == 0, where countLeadingZeros is the full
    -- width.
    sgm :: Int
    sgm = finiteBitSize i - 1 - bitsLog - countLeadingZeros i

    -- Bit j of a segment lives in word j / bits, at bit j mod bits.
    indexU :: U.Vector Word -> Int -> Bool
    indexU vec j = testBit (vec `U.unsafeIndex` jHi) jLo
      where
        jHi = j `shiftR` bitsLog
        jLo = j .&. (bits - 1)
-- | Flip every bit of the stream.
--
-- 'complement' on a 'Word' equals @maxBound - w@, which is what the
-- original (garbled) @(maxBound -)@ section intended.
not :: BitStream -> BitStream
not (BitStream vus) = BitStream $ V.map (U.map complement) vus
-- | Replace every bit with a pure function of its key and current value.
--
-- This is 'traverseWithKey' specialised to the 'Identity' monad.
mapWithKey :: (Word -> Bool -> Bool) -> BitStream -> BitStream
mapWithKey f stream =
  runIdentity (traverseWithKey (\key bit -> Identity (f key bit)) stream)
-- | Rebuild every bit by running a monadic function on its key and value.
--
-- Words are visited segment by segment, in the same layout used by
-- 'tabulateM'. Within a word, keys go from bit 0 upwards.
traverseWithKey
  :: forall m. Monad m
  => (Word -> Bool -> m Bool)
  -> BitStream
  -> m BitStream
traverseWithKey f (BitStream bs) = BitStream <$> V.imapM g bs
  where
    -- Segment 0 holds word 0. Segment s > 0 starts at word index 2^(s-1),
    -- so local word j there has global word index 2^(s-1) + j.
    g :: Int -> U.Vector Word -> m (U.Vector Word)
    g 0 = U.imapM h
    g logOffset = U.imapM (h . (`shiftL` bitsLog) . (+ wordOffset))
      where
        wordOffset = 1 `shiftL` (logOffset - 1)

    -- Rebuild one word whose bit 0 stores key keyBase.
    h :: Int -> Word -> m Word
    h keyBase w = foldlM go zeroBits [0 .. bits - 1]
      where
        go acc k = do
          b <- f (int2word $ keyBase + k) (testBit w k)
          return $ if b then acc `setBit` k else acc
-- | Pointwise conjunction: a key is set iff it is set in both streams.
and :: BitStream -> BitStream -> BitStream
and (BitStream xs) (BitStream ys) =
  BitStream (V.zipWith (\u v -> U.zipWith (.&.) u v) xs ys)
-- | Pointwise disjunction: a key is set iff it is set in either stream.
or :: BitStream -> BitStream -> BitStream
or (BitStream xs) (BitStream ys) =
  BitStream (V.zipWith (\u v -> U.zipWith (.|.) u v) xs ys)
-- | Combine two streams bit by bit with a pure function of the key and both
-- values.
--
-- This is 'zipWithKeyM' specialised to the 'Identity' monad.
zipWithKey :: (Word -> Bool -> Bool -> Bool) -> BitStream -> BitStream -> BitStream
zipWithKey f xs ys =
  runIdentity (zipWithKeyM (\key a b -> Identity (f key a b)) xs ys)
-- | Combine two streams bit by bit with a monadic function of the key and
-- both values.
--
-- Segments and words are paired positionally, so both streams are assumed
-- to have the layout built by this module's constructors.
zipWithKeyM
  :: forall m. Monad m
  => (Word -> Bool -> Bool -> m Bool)
  -> BitStream
  -> BitStream
  -> m BitStream
zipWithKeyM f (BitStream bs1) (BitStream bs2) = BitStream <$> V.izipWithM g bs1 bs2
  where
    -- Segment 0 holds word 0. Segment s > 0 starts at word index 2^(s-1),
    -- so local word j there has global word index 2^(s-1) + j.
    g :: Int -> U.Vector Word -> U.Vector Word -> m (U.Vector Word)
    g 0 = U.izipWithM h
    g logOffset = U.izipWithM (h . (`shiftL` bitsLog) . (+ wordOffset))
      where
        wordOffset = 1 `shiftL` (logOffset - 1)

    -- Combine two words whose bit 0 stores key keyBase.
    h :: Int -> Word -> Word -> m Word
    h keyBase w1 w2 = foldlM go zeroBits [0 .. bits - 1]
      where
        go acc k = do
          b <- f (int2word $ keyBase + k) (testBit w1 k) (testBit w2 k)
          return $ if b then acc `setBit` k else acc