{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Trustworthy #-}
#else
#define safe
#endif
module Data.Vector.Unboxed.Bit
( module Data.Bit
, module U
, wordSize
, wordLength
, fromWords
, toWords
, indexWord
, pad
, padWith
, zipWords
, union
, unions
, intersection
, intersections
, difference
, symDiff
, invert
, select
, selectBits
, exclude
, excludeBits
, countBits
, listBits
, and
, or
, any
, anyBits
, all
, allBits
, reverse
, first
, findIndex
) where
import safe Control.Monad
import Control.Monad.ST
import safe Data.Bit
import safe Data.Bit.Internal
import safe Data.Bits
import safe qualified Data.List as L
import qualified Data.Vector.Generic as V
import qualified Data.Vector.Generic.Mutable as MV
import Data.Vector.Unboxed as U
hiding (and, or, any, all, reverse, findIndex)
import qualified Data.Vector.Unboxed as Unsafe
import safe qualified Data.Vector.Unboxed.Mutable.Bit as B
import Data.Vector.Unboxed.Bit.Internal
import safe Data.Word
import safe Prelude as P
hiding (and, or, any, all, reverse)
-- | Number of 'Word's spanned by a bit vector, i.e. 'nWords' applied to its
-- length in bits.
wordLength :: U.Vector Bit -> Int
wordLength v = nWords (U.length v)
-- | Build a bit vector of length @n@ from a vector of words.
--
-- When the words already hold at least @n@ bits, only the leading
-- @'nWords' n@ words are kept. Otherwise every available bit is used and
-- the result is extended to length @n@ with 'pad'.
fromWords :: Int -> U.Vector Word -> U.Vector Bit
fromWords n ws =
  if n <= available
    then BitVec 0 n (V.take (nWords n) ws)
    else pad n (BitVec 0 available ws)
  where
    -- Total number of bits carried by the input words.
    available = nBits (V.length ws)
-- | View a bit vector as a vector of the words that back it.
--
-- Fast path: the vector starts on a word boundary, and either its length is
-- a whole number of words or 'isMasked' holds for its final partial word.
-- In that case the backing words are returned as a slice, with no copy.
-- Otherwise the bits are copied into fresh words with 'cloneWords'.
toWords :: U.Vector Bit -> U.Vector Word
toWords v@(BitVec s n ws)
  | aligned s && (aligned n || isMasked (modWordSize n) lastWord)
  = V.slice firstWord (nWords n) ws
  | otherwise
  = runST (Unsafe.unsafeThaw v >>= cloneWords >>= Unsafe.unsafeFreeze)
  where
    -- Index in @ws@ of the first word of the vector; the slice starts here.
    firstWord = divWordSize s
    -- The final partial word lies 'divWordSize' @n@ words after the first
    -- word. Index it from the same origin as the slice, not from word 0, so
    -- the check also inspects the right word for a non-zero aligned offset.
    -- It is only demanded when @n@ is not word-aligned.
    lastWord = ws V.! (firstWord + divWordSize n)
-- | Combine two bit vectors word by word with @op@.
--
-- The shorter vector is copied and updated in place with the other one, so
-- the result has the length of the shorter input. When the arguments have
-- to be swapped for that, @op@ is flipped so that words of @xs@ are still
-- passed as its first argument.
zipWords :: (Word -> Word -> Word) -> U.Vector Bit -> U.Vector Bit -> U.Vector Bit
zipWords op xs ys
  | V.length ys < V.length xs = zipWords (flip op) ys xs
  | otherwise = runST $
      V.thaw xs >>= \acc ->
        B.zipInPlace op acc ys >> Unsafe.unsafeFreeze acc
{-# INLINE zipWords #-}
-- | Fold a list of bit vectors into a fresh vector of length @n@.
--
-- Every word of the result starts out as @z@. Each vector of @xss@ is then
-- combined into it, in list order, with 'B.zipInPlace' @op@.
zipMany :: Word -> (Word -> Word -> Word) -> Int -> [U.Vector Bit] -> U.Vector Bit
zipMany z op n xss = runST $ do
  acc <- MV.new n
  B.mapInPlace (\_ -> z) acc
  P.mapM_ (B.zipInPlace op acc) xss
  Unsafe.unsafeFreeze acc
{-# INLINE zipMany #-}
-- | Bitwise OR of two bit vectors (see 'zipWords'; the result has the
-- length of the shorter input).
union :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
union xs ys = zipWords (.|.) xs ys
-- | Bitwise AND of two bit vectors (see 'zipWords'; the result has the
-- length of the shorter input).
intersection :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
intersection xs ys = zipWords (.&.) xs ys
-- | Combine two bit vectors word by word with 'diff' (presumably the bits of
-- the first vector that are not set in the second; see 'diff').
difference :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
difference xs ys = zipWords diff xs ys
-- | Bitwise XOR of two bit vectors (see 'zipWords'; the result has the
-- length of the shorter input).
symDiff :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
symDiff xs ys = zipWords xor xs ys
-- | OR together a list of bit vectors into a result of length @n@.
--
-- The accumulator starts with every word zero, so @unions n []@ has no bits
-- set.
unions :: Int -> [U.Vector Bit] -> U.Vector Bit
unions n xss = zipMany 0 (.|.) n xss
-- | AND together a list of bit vectors into a result of length @n@.
--
-- The accumulator starts with every word set to @complement 0@, so
-- @intersections n []@ has every bit set.
intersections :: Int -> [U.Vector Bit] -> U.Vector Bit
intersections n xss = zipMany (complement 0) (.&.) n xss
-- | Flip every bit of a vector.
--
-- Allocates a new vector of the same length and fills each word with the
-- complement of the corresponding source word from 'indexWord'.
invert :: U.Vector Bit -> U.Vector Bit
invert xs = runST $ do
  out <- MV.new len
  B.mapInPlaceWithIndex flipWord out
  Unsafe.unsafeFreeze out
  where
    len = V.length xs
    -- The existing contents of the fresh vector are ignored.
    flipWord i _ = complement (indexWord xs i)
-- | The elements of @xs@ at the positions where @is@ holds a set bit, in
-- order.
--
-- Only positions below the length of the shorter argument are considered.
-- The result list is produced lazily.
select :: (V.Vector v1 Bit, V.Vector v2 t) => v1 Bit -> v2 t -> [t]
select is xs = go 0
  where
    limit = min (V.length is) (V.length xs)
    go j
      | j >= limit = []
      | toBool (is V.! j) = (xs V.! j) : go (j + 1)
      | otherwise = go (j + 1)
-- | The elements of @xs@ at the positions where @is@ holds a clear bit, in
-- order.
--
-- Only positions below the length of the shorter argument are considered.
-- The result list is produced lazily.
exclude :: (V.Vector v1 Bit, V.Vector v2 t) => v1 Bit -> v2 t -> [t]
exclude is xs = go 0
  where
    limit = min (V.length is) (V.length xs)
    go j
      | j >= limit = []
      | toBool (is V.! j) = go (j + 1)
      | otherwise = (xs V.! j) : go (j + 1)
-- | Run 'B.selectBitsInPlace' with mask @is@ on a copy of @xs@, and return
-- the prefix whose length it reports (presumably the bits of @xs@ at the
-- positions set in @is@, packed together).
selectBits :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
selectBits is xs = runST $
  U.thaw xs >>= \buf ->
    B.selectBitsInPlace is buf >>= \kept ->
      Unsafe.unsafeFreeze (MV.take kept buf)
-- | Run 'B.excludeBitsInPlace' with mask @is@ on a copy of @xs@, and return
-- the prefix whose length it reports (presumably the bits of @xs@ at the
-- positions clear in @is@, packed together).
excludeBits :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
excludeBits is xs = runST $
  U.thaw xs >>= \buf ->
    B.excludeBitsInPlace is buf >>= \kept ->
      Unsafe.unsafeFreeze (MV.take kept buf)
-- | Count the set bits of a vector by summing 'popCount' over the words
-- returned by 'indexWord', stepping 'wordSize' bits at a time.
--
-- NOTE(review): this assumes 'indexWord' reports no bits past the end of
-- the vector for the final partial word.
countBits :: U.Vector Bit -> Int
countBits v = go 0 0
  where
    !limit = alignUp (V.length v)
    go !off !acc
      | off < limit = go (off + wordSize) (acc + popCount (indexWord v off))
      | otherwise = acc
-- | The positions reported by 'bitsInWord' for every word of the vector, in
-- ascending word order (presumably the indices of the set bits).
--
-- The per-word lists are chained as a difference list, which is applied to
-- the empty list once every word has been visited.
listBits :: U.Vector Bit -> [Int]
listBits v = go 0 id
  where
    !len = V.length v
    go !off acc
      | off < len = go (off + wordSize) (acc . bitsInWord off (indexWord v off))
      | otherwise = acc []
-- | Test whether every bit of the vector is set. The empty vector yields
-- 'True'.
--
-- Each word from 'indexWord' is compared with @'mask' (len - i)@, so the
-- final partial word is only checked up to the end of the vector. The scan
-- stops at the first mismatching word.
and :: U.Vector Bit -> Bool
and v = go 0
  where
    !len = V.length v
    go !i = i >= len || (indexWord v i == mask (len - i) && go (i + wordSize))
-- | Test whether any bit of the vector is set. The empty vector yields
-- 'False'.
--
-- The scan stops at the first non-zero word from 'indexWord'.
or :: U.Vector Bit -> Bool
or v = go 0
  where
    !len = V.length v
    go !i = i < len && (indexWord v i /= 0 || go (i + wordSize))
-- | @all p v@ tests whether @p@ holds for every bit of @v@.
--
-- A 'Bit' has only the two values 0 and 1, so @p@ is evaluated once on each.
-- The test then reduces to a word-level scan via 'allBits', an emptiness
-- check, or a constant result.
all :: (Bit -> Bool) -> Vector Bit -> Bool
all p
  | onZero && onOne = \v -> v `seq` True
  | onZero = allBits 0
  | onOne = allBits 1
  | otherwise = U.null
  where
    onZero = p 0
    onOne = p 1
-- | @any p v@ tests whether @p@ holds for at least one bit of @v@.
--
-- A 'Bit' has only the two values 0 and 1, so @p@ is evaluated once on each.
-- The test then reduces to a word-level scan via 'anyBits', a non-emptiness
-- check, or a constant result.
any :: (Bit -> Bool) -> Vector Bit -> Bool
any p
  | onZero && onOne = not . U.null
  | onZero = anyBits 0
  | onOne = anyBits 1
  | otherwise = \v -> v `seq` False
  where
    onZero = p 0
    onOne = p 1
-- | @allBits b v@ tests whether every bit of @v@ equals @b@ (vacuously
-- 'True' for an empty @v@).
--
-- @anyBits b v@ tests whether at least one bit of @v@ equals @b@ ('False'
-- for an empty @v@).
--
-- Both are expressed through the word-level scans 'and' and 'or'.
allBits, anyBits :: Bit -> U.Vector Bit -> Bool
allBits (Bit b) v = if b then and v else not (or v)
anyBits (Bit b) v = if b then or v else not (and v)
-- | Reverse the order of the bits in a vector.
--
-- The output word at bit offset @i@ is built from the input word that ends
-- at bit @len - i@, with its bits reversed by 'reversePartialWord'. Near
-- the front of the input the read offset is clamped to 0, and only the
-- low @len - i@ bits are reversed.
reverse :: U.Vector Bit -> U.Vector Bit
reverse xs = runST $ do
  out <- MV.new len
  B.mapInPlaceWithIndex fromMirror out
  Unsafe.unsafeFreeze out
  where
    !len = V.length xs
    fromMirror i _ =
      let src = max 0 (len - i - wordSize)
      in reversePartialWord (len - i) (indexWord xs src)
-- | Index of the first bit equal to @b@, or 'Nothing' if there is none.
--
-- Words are scanned in order with 'ffs'. When searching for 0, each word is
-- complemented first. That can turn padding bits past the end into
-- candidates, so any hit at or beyond the length is discarded by the final
-- 'mfilter'.
first :: Bit -> U.Vector Bit -> Maybe Int
first b xs = mfilter (< len) (scan 0)
  where
    !len = V.length xs
    !locate
      | toBool b = ffs
      | otherwise = ffs . complement
    scan !i
      | i >= len = Nothing
      | otherwise = case locate (indexWord xs i) of
          Just k -> Just (i + k)
          Nothing -> scan (i + wordSize)
-- | Index of the first bit satisfying @p@, or 'Nothing' if there is none.
--
-- @p@ is evaluated once on 0 and once on 1. The search then becomes a call
-- to 'first' for the matching bit value, a non-emptiness check, or a
-- constant 'Nothing'.
findIndex :: (Bit -> Bool) -> Vector Bit -> Maybe Int
findIndex p xs
  | onZero && onOne = if V.null xs then Nothing else Just 0
  | onZero = first 0 xs
  | onOne = first 1 xs
  | otherwise = Nothing
  where
    onZero = p 0
    onOne = p 1