module Foundation.Array.Bitmap
    ( Bitmap
    , MutableBitmap
    , empty
    , append
    , concat
    , unsafeIndex
    , index
    , read
    , unsafeRead
    , write
    , unsafeWrite
    , snoc
    , cons
    ) where
import           Foundation.Array.Unboxed (UArray)
import qualified Foundation.Array.Unboxed as A
import           Foundation.Array.Unboxed.Mutable (MUArray)
import           Foundation.Array.Common
import           Foundation.Internal.Base
import           Foundation.Internal.Types
import           Foundation.Primitive.Monad
import qualified Foundation.Collection as C
import           Foundation.Numerical
import           Data.Bits
import           Foundation.Bits
import           GHC.ST
import qualified Data.List
-- | Immutable bit vector: the number of valid bits, followed by the backing
-- array of 'Word32' words that stores them 32 per word (see 'bitmapIndex').
-- The array may hold words/bits beyond the valid length (e.g. after 'take');
-- all accessors are bounded by the stored bit count.
data Bitmap = Bitmap Int (UArray Word32)
-- | Mutable counterpart of 'Bitmap', indexed by the state token @st@
-- (instantiated with @PrimState prim@ by the functions below): the bit count
-- and a mutable 'Word32' array holding the bits.
data MutableBitmap st = MutableBitmap Int (MUArray Word32 st)
-- | Number of bits held by each 'Word32' element of the backing array.
bitsPerTy :: Int
bitsPerTy = 0x20
-- | Right-shift that turns a bit offset into a word index (@2 ^ 5 == 32@).
shiftPerTy :: Int
shiftPerTy = 5
-- | Mask that extracts a bit's position inside its word (@32 - 1@).
maskPerTy :: Int
maskPerTy = 31
-- | Shown as the list of its bits, e.g. @[True,False]@.
instance Show Bitmap where
    show = show . toList
-- | Bitwise equality of the valid bits (see 'equal').
instance Eq Bitmap where
    a == b = equal a b
-- | Lexicographic ordering over the bits (see 'vCompare').
instance Ord Bitmap where
    compare x y = vCompare x y
-- | Concatenation monoid: 'empty' is the identity, 'append' the operation.
instance Monoid Bitmap where
    mconcat = concat
    mappend = append
    mempty  = empty
-- Elements of a 'Bitmap' are its individual bits, exposed as 'Bool'.
type instance C.Element Bitmap = Bool
-- | Conversion to and from @[Bool]@, one list element per bit.
instance IsList Bitmap where
    type Item Bitmap = Bool
    toList   = vToList
    fromList = vFromList
-- | Element-wise mapping over bits (see 'map').
instance C.InnerFunctor Bitmap where
    imap f = map f
-- | Folds over the bits in index order, delegating to this module's folds.
instance C.Foldable Bitmap where
    foldr  = foldr
    foldr' = foldr'
    foldl  = foldl
    foldl' = foldl'
-- | Collection queries. 'null' and 'length' read the stored bit count; the
-- remaining queries go through an intermediate @[Bool]@.
instance C.Collection Bitmap where
    null   = null
    length = length
    elem e b   = Data.List.elem e (toList b)
    all p b    = Data.List.all p (toList b)
    any p b    = Data.List.any p (toList b)
    minimum ne = Data.List.minimum (toList (C.getNonEmpty ne))
    maximum ne = Data.List.maximum (toList (C.getNonEmpty ne))
-- | Sequence operations. Most delegate to the top-level functions of this
-- module; 'C.revTake' and 'C.revDrop' round-trip through @[Bool]@.
instance C.Sequential Bitmap where
    -- slicing
    take      = take
    drop      = drop
    splitAt   = splitAt
    revTake n = unoptimised (C.revTake n)
    revDrop n = unoptimised (C.revDrop n)
    -- predicate-driven splitting and searching
    break   = break
    span    = span
    splitOn = splitOn
    filter  = filter
    find    = find
    -- whole-sequence transformations
    reverse     = reverse
    intersperse = intersperse
    sortBy      = sortBy
    -- single-element construction and deconstruction
    singleton b = fromList [b]
    cons        = cons
    snoc        = snoc
    uncons      = uncons
    unsnoc      = unsnoc
-- | Bounds-checked lookup and predicate search over the bits.
instance C.IndexedCollection Bitmap where
    -- 'Nothing' for an out-of-range index. The guard already establishes the
    -- bounds, so the unchecked accessor is safe in the 'otherwise' branch.
    (!) l n
        | n < 0 || n >= length l = Nothing
        | otherwise              = Just $ unsafeIndex l n
    -- Index of the first bit satisfying the predicate, scanning from 0
    -- upwards; 'Nothing' if no bit matches.
    findIndex predicate c = loop 0
      where
        !len = length c
        loop i
            | i == len                    = Nothing
            | predicate (unsafeIndex c i) = Just i
            -- no match at i: keep scanning (must not stop here, or only
            -- index 0 would ever be examined)
            | otherwise                   = loop (i+1)
-- | Mutable-collection interface: keys are bit offets ('Int') and values are
-- bits ('Bool'). Freezing yields a 'Bitmap'.
instance C.MutableCollection MutableBitmap where
    type MutableFreezed MutableBitmap = Bitmap
    type MutableKey MutableBitmap = Int
    type MutableValue MutableBitmap = Bool
    -- conversions between the frozen and mutable forms
    thaw         = thaw
    freeze       = freeze
    unsafeThaw   = unsafeThaw
    unsafeFreeze = unsafeFreeze
    -- allocation with a bit count
    mutNew = new . Size
    -- bounds-checked access
    mutRead  = read
    mutWrite = write
    -- unchecked access
    mutUnsafeRead  = unsafeRead
    mutUnsafeWrite = unsafeWrite
-- | Split a bit offset into @(word index, bit position within that word)@:
-- the word is @offset .>>. 5@ and the bit is @offset .&. 31@.
bitmapIndex :: Offset Bool -> (Int, Int)
bitmapIndex (Offset !i) = (wordIx, bitIx)
  where
    wordIx = i .>>. shiftPerTy
    bitIx  = i .&. maskPerTy
-- | Mutable bitmap obtained by applying 'C.thaw' to the word array; the bit
-- count is carried over unchanged.
thaw :: PrimMonad prim => Bitmap -> prim (MutableBitmap (PrimState prim))
thaw (Bitmap nbBits arr) = fmap (MutableBitmap nbBits) (C.thaw arr)
-- | Immutable bitmap obtained by applying 'C.freeze' to the word array; the
-- bit count is carried over unchanged.
freeze :: PrimMonad prim => MutableBitmap (PrimState prim) -> prim Bitmap
freeze (MutableBitmap nbBits marr) = Bitmap nbBits <$> C.freeze marr
-- | Like 'thaw', but uses 'C.unsafeThaw' on the word array.
-- NOTE(review): presumably shares storage with the input, so the input
-- 'Bitmap' should not be used afterwards; confirm against 'C.unsafeThaw'.
unsafeThaw :: PrimMonad prim => Bitmap -> prim (MutableBitmap (PrimState prim))
unsafeThaw (Bitmap nbBits arr) = fmap (MutableBitmap nbBits) (C.unsafeThaw arr)
-- | Like 'freeze', but uses 'C.unsafeFreeze' on the word array.
-- NOTE(review): presumably shares storage, so the mutable bitmap should not
-- be written afterwards; confirm against 'C.unsafeFreeze'.
unsafeFreeze :: PrimMonad prim => MutableBitmap (PrimState prim) -> prim Bitmap
unsafeFreeze (MutableBitmap nbBits marr) = Bitmap nbBits <$> C.unsafeFreeze marr
-- | Set (@True@) or clear (@False@) the bit at the given offset, without a
-- bounds check. The containing word is read, updated, and written back.
unsafeWrite :: PrimMonad prim => MutableBitmap (PrimState prim) -> Int -> Bool -> prim ()
unsafeWrite (MutableBitmap _ marr) i v = do
    old <- A.unsafeRead marr wordIx
    A.unsafeWrite marr wordIx (updateBit old)
  where
    (wordIx, bitIx) = bitmapIndex (Offset i)
    updateBit w
        | v         = setBit w bitIx
        | otherwise = clearBit w bitIx
-- | Read the bit at the given offset, without a bounds check.
unsafeRead :: PrimMonad prim => MutableBitmap (PrimState prim) -> Int -> prim Bool
unsafeRead (MutableBitmap _ marr) i =
    fmap (`testBit` bitIx) (A.unsafeRead marr wordIx)
  where
    (wordIx, bitIx) = bitmapIndex (Offset i)
-- | Bounds-checked 'unsafeWrite'. Throws @OutOfBound OOB_Write@ (via
-- 'primThrow') when the offset is negative or not below the bit count.
write :: PrimMonad prim => MutableBitmap (PrimState prim) -> Int -> Bool -> prim ()
write mb n val
    | inBounds  = unsafeWrite mb n val
    | otherwise = primThrow (OutOfBound OOB_Write n len)
  where
    len      = mutableLength mb
    inBounds = n >= 0 && n < len
-- | Bounds-checked 'unsafeRead'. Throws @OutOfBound OOB_Read@ (via
-- 'primThrow') when the offset is negative or not below the bit count.
read :: PrimMonad prim => MutableBitmap (PrimState prim) -> Int -> prim Bool
read mb n
    | inBounds  = unsafeRead mb n
    | otherwise = primThrow (OutOfBound OOB_Read n len)
  where
    len      = mutableLength mb
    inBounds = n >= 0 && n < len
-- | Bounds-checked 'unsafeIndex'. Throws @OutOfBound OOB_Index@ (via
-- 'throw') when the offset is negative or not below the bit count.
index :: Bitmap -> Int -> Bool
index bits n
    | inBounds  = unsafeIndex bits n
    | otherwise = throw (OutOfBound OOB_Index n len)
  where
    len      = length bits
    inBounds = n >= 0 && n < len
-- | Bit at the given offset, without a bounds check.
unsafeIndex :: Bitmap -> Int -> Bool
unsafeIndex (Bitmap _ arr) n = testBit (A.unsafeIndex arr wordIx) bitIx
  where
    (wordIx, bitIx) = bitmapIndex (Offset n)
-- | Number of valid bits (O(1), read from the stored count).
length :: Bitmap -> Int
length bitmap = case bitmap of
    Bitmap nbBits _ -> nbBits
-- | Number of valid bits of a mutable bitmap (read from the stored count).
mutableLength :: MutableBitmap st -> Int
mutableLength mbitmap = case mbitmap of
    MutableBitmap nbBits _ -> nbBits
-- | The bitmap with no bits: length 0 over an empty word array.
empty :: Bitmap
empty = Bitmap 0 A.empty
-- | Allocate a mutable bitmap holding the given number of bits. The word
-- count is @len@ aligned up to 'bitsPerTy' (presumably to the next multiple
-- of 32, via 'alignRoundUp'), shifted right by 'shiftPerTy'. Initial bit
-- contents are not set here; they are whatever 'A.new' provides.
new :: PrimMonad prim => Size Bool -> prim (MutableBitmap (PrimState prim))
new (Size len) = fmap (MutableBitmap len) (A.new wordCount)
  where
    wordCount :: Size Word32
    wordCount = Size (alignRoundUp len bitsPerTy .>>. shiftPerTy)
-- | Build a bitmap from a list of bits. A mutable bitmap of the list's length
-- is allocated, each element is written at its position, and the result is
-- frozen without copying.
vFromList :: [Bool] -> Bitmap
vFromList allBools = runST (new (Size nbBits) >>= \mb -> fill mb 0 allBools)
  where
    nbBits = C.length allBools
    fill mb _ []     = unsafeFreeze mb
    fill mb i (b:bs) = do
        unsafeWrite mb i b
        fill mb (i+1) bs
-- | Lazily list the valid bits in index order.
vToList :: Bitmap -> [Bool]
vToList a = go 0
  where
    n = length a
    go i = if i == n then [] else unsafeIndex a i : go (i+1)
-- | Two bitmaps are equal when they have the same length and the same bit at
-- every valid offset. Bits stored past the length are ignored.
equal :: Bitmap -> Bitmap -> Bool
equal a b = la == lb && sameFrom 0
  where
    !la = length a
    !lb = length b
    sameFrom n
        | n == la                            = True
        | unsafeIndex a n /= unsafeIndex b n = False
        | otherwise                          = sameFrom (n+1)
-- | Lexicographic comparison of the bits (@False < True@). A strict prefix
-- compares as smaller.
vCompare :: Bitmap -> Bitmap -> Ordering
vCompare a b = go 0
  where
    !la = length a
    !lb = length b
    go n
        -- End of @a@ reached first (so @la <= lb@): EQ for equal lengths,
        -- LT when @a@ is a strict prefix of @b@.
        | n == la   = compare la lb
        | n == lb   = GT
        | otherwise = case compare (unsafeIndex a n) (unsafeIndex b n) of
            EQ    -> go (n+1)
            other -> other
-- | Concatenate two bitmaps (O(n): rebuilds from the combined bit list).
append :: Bitmap -> Bitmap -> Bitmap
append a b = fromList (mappend (toList a) (toList b))
-- | Concatenate a list of bitmaps (O(n): rebuilds from the combined bits).
concat :: [Bitmap] -> Bitmap
concat = fromList . mconcat . fmap toList
-- | True when the bitmap holds no valid bits.
null :: Bitmap -> Bool
null bitmap = case bitmap of
    Bitmap nbBits _ -> nbBits == 0
-- | First @nbElems@ bits. O(1): the partial case shares the original word
-- array and only shrinks the stored length. The trailing bits stay in the
-- array but are ignored.
take :: Int -> Bitmap -> Bitmap
take nbElems bits@(Bitmap nbBits ba) =
    if nbElems <= 0
        then empty
        else if nbElems >= nbBits
            then bits
            else Bitmap nbElems ba
-- | All but the first @nbElems@ bits. A partial drop rebuilds the bitmap
-- through an intermediate @[Bool]@.
drop :: Int -> Bitmap -> Bitmap
drop nbElems bits@(Bitmap nbBits _) =
    if nbElems <= 0
        then bits
        else if nbElems >= nbBits
            then empty
            else unoptimised (C.drop nbElems) bits
-- | @('take' n v, 'drop' n v)@.
splitAt :: Int -> Bitmap -> (Bitmap, Bitmap)
splitAt n v = (front, back)
  where
    front = take n v
    back  = drop n v
-- | Split wherever a bit satisfies the predicate, using 'C.splitOn' on the
-- bit list.
splitOn :: (Bool -> Bool) -> Bitmap -> [Bitmap]
splitOn f = fmap fromList . C.splitOn f . toList
-- | Split before the first bit satisfying the predicate. If no bit matches,
-- the result is @(v, 'empty')@.
break :: (Bool -> Bool) -> Bitmap -> (Bitmap, Bitmap)
break predicate v = go 0
  where
    go i
        | i == length v               = (v, empty)
        | predicate (unsafeIndex v i) = splitAt i v
        | otherwise                   = go (i+1)
-- | Split after the longest prefix whose bits all satisfy the predicate.
span :: (Bool -> Bool) -> Bitmap -> (Bitmap, Bitmap)
span p = break (\x -> not (p x))
-- | Apply a function to every bit (via a @[Bool]@ round-trip).
map :: (Bool -> Bool) -> Bitmap -> Bitmap
map f = unoptimised (fmap f)
-- | Prepend a bit (O(n): rebuilds via a @[Bool]@ round-trip).
cons :: Bool -> Bitmap -> Bitmap
cons v = unoptimised (C.cons v)
-- | Append a bit at the end (O(n): rebuilds via a @[Bool]@ round-trip).
snoc :: Bitmap -> Bool -> Bitmap
snoc l v = unoptimised (\xs -> C.snoc xs v) l
-- | First bit and the remaining bitmap, or 'Nothing' when empty. The
-- remainder is rebuilt from a list, so this is O(n).
uncons :: Bitmap -> Maybe (Bool, Bitmap)
uncons b = case C.uncons (toList b) of
    Nothing        -> Nothing
    Just (v, rest) -> Just (v, fromList rest)
-- | Bitmap without its last bit, together with that bit, or 'Nothing' when
-- empty. The front part is rebuilt from a list, so this is O(n).
unsnoc :: Bitmap -> Maybe (Bitmap, Bool)
unsnoc b = case C.unsnoc (toList b) of
    Nothing         -> Nothing
    Just (front, v) -> Just (fromList front, v)
-- | Insert the given bit between every pair of adjacent bits.
intersperse :: Bool -> Bitmap -> Bitmap
intersperse b bits = unoptimised (C.intersperse b) bits
-- | First bit (scanning from index 0) that satisfies the predicate, if any.
find :: (Bool -> Bool) -> Bitmap -> Maybe Bool
find predicate vec = go 0
  where
    !len = length vec
    go i
        | i == len            = Nothing
        | predicate candidate = Just candidate
        | otherwise           = go (i+1)
      where
        candidate = unsafeIndex vec i
-- | Sort the bits with the given comparison (via a @[Bool]@ round-trip).
sortBy :: (Bool -> Bool -> Ordering) -> Bitmap -> Bitmap
sortBy by = unoptimised (C.sortBy by)
-- | Keep only the bits satisfying the predicate (via a @[Bool]@ round-trip).
filter :: (Bool -> Bool) -> Bitmap -> Bitmap
filter predicate = unoptimised (Data.List.filter predicate)
-- | Reverse the bit order (via a @[Bool]@ round-trip).
reverse :: Bitmap -> Bitmap
reverse = unoptimised C.reverse
-- | Lazy left fold over the bits in index order. The accumulator is not
-- forced; use 'foldl'' for a strict fold.
foldl :: (a -> Bool -> a) -> a -> Bitmap -> a
foldl f initialAcc vec = go 0 initialAcc
  where
    n = length vec
    go i acc = if i == n then acc else go (i+1) (f acc (unsafeIndex vec i))
-- | Lazy right fold over the bits: @f b0 (f b1 (... (f bN initialAcc)))@.
foldr :: (Bool -> a -> a) -> a -> Bitmap -> a
foldr f initialAcc vec = go 0
  where
    n = length vec
    go i = if i == n then initialAcc else f (unsafeIndex vec i) (go (i+1))
-- | Right fold that is strict in the accumulator. Bits are consumed from the
-- last one to the first, and each intermediate accumulator is forced to WHNF
-- before the next bit is applied. The result equals @'foldr' f initialAcc@
-- whenever that is defined. Previously this was an alias for the lazy
-- 'foldr', which builds an n-deep chain for strict @f@.
foldr' :: (Bool -> a -> a) -> a -> Bitmap -> a
foldr' f initialAcc vec =
    Data.List.foldl' (flip f) initialAcc (Data.List.reverse (vToList vec))
-- | Strict left fold over the bits in index order. The accumulator is forced
-- to WHNF at every step.
foldl' :: (a -> Bool -> a) -> a -> Bitmap -> a
foldl' f initialAcc vec = go 0 initialAcc
  where
    n = length vec
    go i !acc
        | i /= n    = go (i+1) (f acc (unsafeIndex vec i))
        | otherwise = acc
-- | Lift a list transformation to bitmaps by converting to @[Bool]@, applying
-- it, and rebuilding. O(n) fallback for operations without a packed version.
unoptimised :: ([Bool] -> [Bool]) -> Bitmap -> Bitmap
unoptimised f bits = vFromList (f (vToList bits))