module Bio.Utils.BitVector
    ( BitVector(..)
    , BitMVector
    , size
    , (!)
    , set
    , clear
    , unsafeFreeze
    , zeros
    , toList
    ) where

import qualified Data.Vector.Unboxed as U
import Control.Monad.Primitive
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.Word
import Data.Bits
import Text.Printf (printf)

-- | Immutable packed bit vector. The 'Int' is the logical number of bits.
-- The unboxed 'Word8' vector stores them eight per byte: bit @i@ lives
-- at bit position @i `mod` 8@ of byte @i `div` 8@.
data BitVector = BitVector Int (U.Vector Word8)

-- | Mutable counterpart of 'BitVector' with the same layout. The 'Int'
-- is the logical bit count. @s@ is the state token of the 'PrimMonad'
-- that owns the underlying mutable byte vector.
data BitMVector s = BitMVector Int (UM.MVector s Word8)

-- | Logical length of the vector in bits, as stored in its length field.
-- This is not the number of backing bytes.
size :: BitVector -> Int
size bv = case bv of
    BitVector len _ -> len

-- | Infix alias for 'index': @bv ! i@ reads bit @i@ of @bv@ and errors
-- when @i@ is out of range.
(!) :: BitVector -> Int -> Bool
bv ! i = index bv i

-- | Read the bit at position @idx@.
--
-- Calls 'error' with @"index out of bounds (idx,n)"@ unless
-- @0 <= idx < n@, where @n@ is the logical bit length.
--
-- The negative-index check is essential. The byte read below uses
-- 'U.unsafeIndex', which does no bounds checking. Without the check, a
-- negative @idx@ (for which @idx `div` 8@ is negative) would read
-- outside the backing vector.
--
-- NOTE(review): this still assumes the byte vector holds at least
-- @ceiling (n / 8)@ bytes, as 'zeros' allocates. A 'BitVector' built by
-- hand through the exported constructor with a shorter vector is not
-- protected.
index :: BitVector -> Int -> Bool
index (BitVector n v) idx
    | idx < 0 || idx >= n = error $ printf "index out of bounds (%d,%d)" idx n
    | otherwise           = testBit (v `U.unsafeIndex` byteIx) bitIx
  where
    -- byte offset and bit position within that byte
    (byteIx, bitIx) = idx `divMod` 8

-- | Set the bit at position @idx@ to 1, in place.
--
-- Calls 'error' with @"index out of bounds (idx,n)"@ unless
-- @0 <= idx < n@, where @n@ is the logical bit length. This matches
-- 'index'.
--
-- Previously only 'UM.modify''s byte-level bounds check applied. That
-- check let writes to the padding bits of the last byte
-- (@n <= idx < 8 * bytes@) succeed silently, even though reading those
-- bits back with 'index' errors.
set :: PrimMonad m => BitMVector (PrimState m) -> Int -> m ()
set (BitMVector n mv) idx
    | idx < 0 || idx >= n = error $ printf "index out of bounds (%d,%d)" idx n
    | otherwise           = UM.modify mv (`setBit` bitIx) byteIx
  where
    -- byte offset and bit position within that byte
    (byteIx, bitIx) = idx `divMod` 8

-- | Clear the bit at position @idx@ (set it to 0), in place.
--
-- Calls 'error' with @"index out of bounds (idx,n)"@ unless
-- @0 <= idx < n@, where @n@ is the logical bit length. This matches
-- 'index'.
--
-- Previously only 'UM.modify''s byte-level bounds check applied. That
-- check silently accepted indices that fall in the padding bits of the
-- last byte.
clear :: PrimMonad m => BitMVector (PrimState m) -> Int -> m ()
clear (BitMVector n mv) idx
    | idx < 0 || idx >= n = error $ printf "index out of bounds (%d,%d)" idx n
    | otherwise           = UM.modify mv (`clearBit` bitIx) byteIx
  where
    -- byte offset and bit position within that byte
    (byteIx, bitIx) = idx `divMod` 8

-- | Turn a mutable bit vector into an immutable one without copying,
-- keeping its logical length.
--
-- This delegates to 'U.unsafeFreeze'. As with that function, the
-- mutable vector must not be modified afterwards, or the frozen result
-- may change.
unsafeFreeze :: PrimMonad m => BitMVector (PrimState m) -> m BitVector
unsafeFreeze (BitMVector n mv) = do
    frozen <- U.unsafeFreeze mv
    return (BitVector n frozen)

-- | Allocate a mutable bit vector of @n@ bits, all cleared.
--
-- Storage is rounded up to whole bytes, so the final byte may contain
-- unused padding bits beyond position @n - 1@.
zeros :: PrimMonad m => Int -> m (BitMVector (PrimState m))
zeros n = do
    bytes <- UM.replicate nBytes 0
    return (BitMVector n bytes)
  where
    -- ceiling (n / 8), computed via divMod so it cannot overflow the way
    -- (n + 7) `div` 8 could near maxBound
    nBytes :: Int
    nBytes = case n `divMod` 8 of
        (q, 0) -> q
        (q, _) -> q + 1

-- | All bits of the vector in index order, from position 0 up to
-- @size bv - 1@.
toList :: BitVector -> [Bool]
toList bv = [bv ! k | k <- [0 .. size bv - 1]]