{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE BangPatterns               #-}
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Trustworthy #-}
#else
#define safe
#endif
module Data.Vector.Unboxed.Mutable.Bit
     ( module Data.Bit
     , module U
     
     , wordSize
     , wordLength
     , cloneFromWords
     , cloneToWords
     , readWord
     , writeWord
     
     , mapMInPlaceWithIndex
     , mapInPlaceWithIndex
     , mapMInPlace
     , mapInPlace
     
     , zipInPlace
     
     , unionInPlace
     , intersectionInPlace
     , differenceInPlace
     , symDiffInPlace
     , invertInPlace
     , selectBitsInPlace
     , excludeBitsInPlace
     
     , countBits
     , listBits
     
     , and
     , or
     
     , any
     , anyBits
     , all
     , allBits
     
     , reverseInPlace
     ) where

import safe           Control.Monad
import                Control.Monad.Primitive
import safe           Data.Bit
import safe           Data.Bit.Internal
import safe           Data.Bits
import      qualified Data.Vector.Generic.Mutable       as MV
import safe qualified Data.Vector.Generic.Safe          as V
import safe qualified Data.Vector.Unboxed.Safe          as U (Vector)
import safe           Data.Vector.Unboxed.Mutable.Safe  as U
import                Data.Vector.Unboxed.Bit.Internal
import safe           Data.Word
import safe           Prelude                           as P
    hiding (and, or, any, all, reverse)


-- TODO: this interface needs more work.

-- |The length of the word vector that 'cloneToWords' would build from the given bit vector (obtained by applying 'nWords' to its length in bits).
wordLength :: U.MVector s Bit -> Int
wordLength bits = nWords (MV.length bits)

-- |Build a new bit vector holding the requested number of bits, taking its contents from a vector of words (the words are interpreted in little-endian order, as described at 'indexWord').  When the source holds fewer words than the requested bit count needs, the missing words are filled with zeros.
cloneFromWords :: PrimMonad m => Int -> U.MVector (PrimState m) Word -> m (U.MVector (PrimState m) Bit)
cloneFromWords n ws = do
    target <- MV.new needed
    if shortfall > 0
        then do
            -- Too few source words: copy all of them, then zero the tail.
            MV.copy (MV.slice 0 available target) ws
            MV.set  (MV.slice available shortfall target) 0
        else
            -- Enough source words: copy only the prefix that is needed.
            MV.copy target (MV.slice 0 needed ws)
    return (BitMVec 0 n target)
    where
        needed    = nWords n
        available = MV.length ws
        shortfall = needed - available

-- |Clone a vector of bits into a new unboxed vector of words.  If the bits do not completely fill the last word, that word will be zero-padded.
cloneToWords :: PrimMonad m => U.MVector (PrimState m) Bit -> m (U.MVector (PrimState m) Word)
cloneToWords v@(BitMVec s n ws)
    -- A bit vector that does not start on a word boundary is handed to 'cloneWords'.
    | not (aligned s) = cloneWords v
    | otherwise       = do
        -- Word-aligned start: clone the relevant run of backing words directly.
        copied <- MV.clone (MV.slice (divWordSize s) (nWords n) ws)
        -- When the length is not a multiple of the word size, the final word is
        -- re-read through 'readWord' and written over the raw copy
        -- (presumably so that bits past the end come out as zero).
        unless (aligned n) $
            readWord v (alignDown n) >>= MV.write copied (divWordSize n)
        return copied

-- |Apply a monadic function to a bit vector one 'Word' ('wordSize' bits) at a time, storing each result back into the vector.  The function receives the bit index of the word (always a multiple of 'wordSize') together with the word's current value.  A partial word at the end of the vector is zero-padded when passed to the function, and the result is truncated when it is written back.
{-# INLINE mapMInPlaceWithIndex #-}
mapMInPlaceWithIndex ::
    PrimMonad m =>
        (Int -> Word -> m Word)
     -> U.MVector (PrimState m) Bit -> m ()
-- Fast path: the bit vector starts at offset 0 of its backing words, so every
-- complete word can be read from and written to the backing vector directly.
mapMInPlaceWithIndex f xs@(BitMVec 0 _ v) = go 0 0
    where
        !total     = MV.length xs
        !fullLimit = alignDown total
        -- bitIx: bit index of the current word; wordIx: its index in the backing vector.
        go !bitIx !wordIx
            | bitIx < fullLimit  = do
                MV.read v wordIx >>= f bitIx >>= MV.write v wordIx
                go (bitIx + wordSize) (wordIx + 1)
            -- All complete words are done; handle a trailing partial word, if any.
            | fullLimit /= total =
                readWord xs bitIx >>= f bitIx >>= writeWord xs bitIx
            | otherwise          = return ()
-- General path: every word goes through 'readWord' / 'writeWord'.
mapMInPlaceWithIndex f xs = go 0
    where
        !total = MV.length xs
        go !bitIx
            | bitIx < total = do
                readWord xs bitIx >>= f bitIx >>= writeWord xs bitIx
                go (bitIx + wordSize)
            | otherwise     = return ()

-- |Pure counterpart of 'mapMInPlaceWithIndex': the supplied function is applied to each (aligned bit index, word) pair and its result, forced before being returned, is written back to the vector.
{-# INLINE mapInPlaceWithIndex #-}
mapInPlaceWithIndex ::
    PrimMonad m =>
        (Int -> Word -> Word)
     -> U.MVector (PrimState m) Bit -> m ()
mapInPlaceWithIndex f = mapMInPlaceWithIndex (\ix w -> return $! f ix w)

-- |Like 'mapMInPlaceWithIndex', except that the function is not told the bit index of the word it is given.
{-# INLINE mapMInPlace #-}
mapMInPlace :: PrimMonad m => (Word -> m Word) -> U.MVector (PrimState m) Bit -> m ()
mapMInPlace f = mapMInPlaceWithIndex (\_ -> f)

-- |Apply a pure function to a bit vector one 'Word' at a time, writing each result back in place.  Built on 'mapMInPlaceWithIndex', ignoring the index.
{-# INLINE mapInPlace #-}
mapInPlace :: PrimMonad m => (Word -> Word) -> U.MVector (PrimState m) Bit -> m ()
mapInPlace f = mapMInPlaceWithIndex (\_ -> return . f)

-- |Combine a mutable bit vector with an immutable one, a 'Word' at a time, storing the result in the mutable vector.  The function receives the word from the mutable vector first and the matching word from the immutable vector second.  Only the first @min (length xs) (length ys)@ bits of the mutable vector are visited.
{-# INLINE zipInPlace #-}
zipInPlace :: PrimMonad m => (Word -> Word -> Word) -> U.MVector (PrimState m) Bit -> U.Vector Bit -> m ()
-- Fast path: the immutable vector starts at offset 0 of its backing words, so
-- its words can be indexed directly.
zipInPlace f xs ys@(BitVec 0 ysLen ysWords) =
    mapInPlaceWithIndex combine (MV.basicUnsafeSlice 0 common xs)
    where
        !common = min (MV.length xs) (V.length ys)
        -- WARNING: relies on the guarantee by mapMInPlaceWithIndex that the index is always aligned!
        -- The backing word is masked to the bits that actually belong to ys.
        {-# INLINE combine #-}
        combine !bitIx !w =
            f w $! masked (ysLen - bitIx) (ysWords V.! divWordSize bitIx)
-- General path: fetch each word of the immutable vector through 'indexWord'.
zipInPlace f xs ys =
    mapInPlaceWithIndex combine (MV.basicUnsafeSlice 0 common xs)
    where
        !common = min (MV.length xs) (V.length ys)
        {-# INLINE combine #-}
        combine !bitIx !w =
            f w $! indexWord ys bitIx

-- |Bitwise OR of the immutable vector into the mutable one (via 'zipInPlace' with '.|.').
unionInPlace :: PrimMonad m => U.MVector (PrimState m) Bit -> U.Vector Bit -> m ()
unionInPlace xs ys = zipInPlace (.|.) xs ys

-- |Bitwise AND of the immutable vector into the mutable one (via 'zipInPlace' with '.&.').
intersectionInPlace :: PrimMonad m => U.MVector (PrimState m) Bit -> U.Vector Bit -> m ()
intersectionInPlace xs ys = zipInPlace (.&.) xs ys

-- |Combine the immutable vector into the mutable one using 'diff' (via 'zipInPlace').
-- NOTE(review): 'diff' is defined elsewhere; presumably it clears in the first word every bit that is set in the second - confirm.
differenceInPlace :: PrimMonad m => U.MVector (PrimState m) Bit -> U.Vector Bit -> m ()
differenceInPlace xs ys = zipInPlace diff xs ys

-- |Bitwise XOR of the immutable vector into the mutable one (via 'zipInPlace' with 'xor').
symDiffInPlace :: PrimMonad m => U.MVector (PrimState m) Bit -> U.Vector Bit -> m ()
symDiffInPlace xs ys = zipInPlace xor xs ys

-- |Flip every bit of the given vector in place, by applying 'complement' to it a word at a time.
invertInPlace :: PrimMonad m => U.MVector (PrimState m) Bit -> m ()
invertInPlace xs = mapInPlace complement xs

-- |Walk the mutable vector a word at a time, pass each word together with the matching selector word from the immutable vector to 'selectWord', and write the resulting word back starting at the running output position.  Only the first @min (length is) (length xs)@ bits are considered.  Returns the final output position, i.e. the sum of the counts reported by 'selectWord'.
-- NOTE(review): 'selectWord' is defined elsewhere; presumably it packs the selected bits into the low end of the word and reports how many there were - confirm.
selectBitsInPlace :: PrimMonad m => U.Vector Bit -> U.MVector (PrimState m) Bit -> m Int
selectBitsInPlace is xs = go 0 0
    where
        !limit = min (V.length is) (MV.length xs)
        -- src: bit index being read; dst: bit index where the next output is written.
        go !src !dst
            | src < limit = do
                word <- readWord xs src
                let selector        = masked (limit - src) (indexWord is src)
                    !(kept, packed) = selectWord selector word
                writeWord xs dst packed
                go (src + wordSize) (dst + kept)
            | otherwise   = return dst

-- |Same traversal as 'selectBitsInPlace', but each selector word from the immutable vector is complemented before being masked and passed to 'selectWord'.  Only the first @min (length is) (length xs)@ bits are considered.  Returns the final output position, i.e. the sum of the counts reported by 'selectWord'.
-- NOTE(review): 'selectWord' is defined elsewhere; presumably it packs the selected bits into the low end of the word and reports how many there were - confirm.
excludeBitsInPlace :: PrimMonad m => U.Vector Bit -> U.MVector (PrimState m) Bit -> m Int
excludeBitsInPlace is xs = go 0 0
    where
        !limit = min (V.length is) (MV.length xs)
        -- src: bit index being read; dst: bit index where the next output is written.
        go !src !dst
            | src < limit = do
                word <- readWord xs src
                let selector        = masked (limit - src) (complement (indexWord is src))
                    !(kept, packed) = selectWord selector word
                writeWord xs dst packed
                go (src + wordSize) (dst + kept)
            | otherwise   = return dst

-- |Return the number of ones in a bit vector, summing 'popCount' over the vector one word at a time.
countBits :: PrimMonad m => U.MVector (PrimState m) Bit -> m Int
countBits v = go 0 0
    where
        !limit = alignUp (MV.length v)
        -- acc: ones counted so far; ix: bit index of the next word to read.
        go !acc !ix
            | ix < limit = do
                w <- readWord v ix
                go (acc + popCount w) (ix + wordSize)
            | otherwise  = return acc

-- |Read the vector a word at a time and concatenate, in order of increasing word index, the lists that 'bitsInWord' produces for each (bit index, word) pair.
-- NOTE(review): 'bitsInWord' is defined elsewhere; presumably it prepends the indices of the set bits of the word - confirm.
listBits :: PrimMonad m => U.MVector (PrimState m) Bit -> m [Int]
listBits v = go id 0
    where
        !len = MV.length v
        -- The result is accumulated as a difference list so each word's
        -- contribution is appended in constant time.
        go build !ix
            | ix < len  = do
                w <- readWord v ix
                go (build . bitsInWord ix w) (ix + wordSize)
            | otherwise = return $! build []

-- | 'True' exactly when every bit of the vector is set (so an empty vector yields 'True').  Stops at the first word that is not fully set.
and :: PrimMonad m => U.MVector (PrimState m) Bit -> m Bool
and v = go 0
    where
        !len = MV.length v
        go !ix
            | ix < len  = do
                w <- readWord v ix
                -- Each word is compared against a mask of the bits still remaining.
                if w /= mask (len - ix)
                    then return False
                    else go (ix + wordSize)
            | otherwise = return True

-- | 'True' exactly when at least one bit of the vector is set (so an empty vector yields 'False').  Stops at the first non-zero word.
or :: PrimMonad m => U.MVector (PrimState m) Bit -> m Bool
or v = go 0
    where
        !len = MV.length v
        go !ix
            | ix < len  = do
                w <- readWord v ix
                if w == 0
                    then go (ix + wordSize)
                    else return True
            | otherwise = return False

-- |Test whether every bit of the vector satisfies the predicate.  The predicate is evaluated once on each of the two bit values and the question is then answered by a word-level scan (or without looking at the contents at all).
all :: PrimMonad m => (Bit -> Bool) -> U.MVector (PrimState m) Bit -> m Bool
all p
    -- Both values accepted: always true (the vector is still forced).
    | onZero && onOne = flip seq (return True)
    -- Only 0 accepted: every bit must be 0.
    | onZero          = allBits 0
    -- Only 1 accepted: every bit must be 1.
    | onOne           = allBits 1
    -- Nothing accepted: true only for the empty vector.
    | otherwise       = return . MV.null
    where
        onZero = p 0
        onOne  = p 1

-- |Test whether at least one bit of the vector satisfies the predicate.  The predicate is evaluated once on each of the two bit values and the question is then answered by a word-level scan (or without looking at the contents at all).
any :: PrimMonad m => (Bit -> Bool) -> U.MVector (PrimState m) Bit -> m Bool
any p
    -- Both values accepted: true for any non-empty vector.
    | onZero && onOne = return . not . MV.null
    -- Only 0 accepted: some bit must be 0.
    | onZero          = anyBits 0
    -- Only 1 accepted: some bit must be 1.
    | onOne           = anyBits 1
    -- Nothing accepted: always false (the vector is still forced).
    | otherwise       = flip seq (return False)
    where
        onZero = p 0
        onOne  = p 1

-- |@allBits b v@ is 'True' when every bit of @v@ equals @b@; @anyBits b v@ is 'True' when at least one bit of @v@ equals @b@.  Both are answered with the word-level scans 'and' and 'or'.
--
-- The bit argument is compared against the literal 0 with '=='; any other value is treated as 1.  This keeps both functions total, where matching on the literals 0 and 1 left a non-exhaustive pattern.
allBits, anyBits :: PrimMonad m => Bit -> U.MVector (PrimState m) Bit -> m Bool
allBits b
    -- All bits are 0 exactly when no bit is set.
    | b == 0    = liftM not . or
    | otherwise = and

anyBits b
    -- Some bit is 0 exactly when not all bits are set.
    | b == 0    = liftM not . and
    | otherwise = or

-- |Reverse the order of the bits of the vector in place, working inward from both ends one word at a time.
reverseInPlace :: PrimMonad m => U.MVector (PrimState m) Bit -> m ()
reverseInPlace xs = loop 0 (MV.length xs)
    where
        -- Invariant: bits in the half-open range [i, j) still have to be reversed.
        -- i' and j' (bound in the trailing where) are i and j moved one word inward.
        loop !i !j
            -- At least two whole, non-overlapping words remain: exchange the
            -- first and last of them, passing each through 'reverseWord'.
            | i' <= j'  = do
                x <- readWord xs i
                y <- readWord xs j'
                
                writeWord xs i  (reverseWord y)
                writeWord xs j' (reverseWord x)
                
                loop i' j'
            -- More than one word but fewer than two remain: exchange
            -- w = (j - i) / 2 bits (rounded down) between the two ends.  Both
            -- words are read before either write, since the word at i and the
            -- word at k overlap.
            -- NOTE(review): presumably @meld w a b@ takes the low w bits from a
            -- and the remaining bits from b, and @reversePartialWord w@ reverses
            -- the low w bits - confirm against their definitions.
            | i' < j    = do
                let w = (j - i) `shiftR` 1
                    k  = j - w
                x <- readWord xs i
                y <- readWord xs k
                
                writeWord xs i (meld w (reversePartialWord w y) x)
                writeWord xs k (meld w (reversePartialWord w x) y)
                
                -- Here i' > j' (the first guard failed), so this call falls
                -- through to the final guard and returns.
                loop i' j'
            -- At most one word remains: reverse its j - i bits within the
            -- single word read at i.
            | i  < j    = do
                let w = j - i
                x <- readWord xs i
                writeWord xs i (meld w (reversePartialWord w x) x)
            -- Nothing left to reverse.
            | otherwise = return ()
            where    
                !i' = i + wordSize
                !j' = j - wordSize