{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module HaskellWorks.Data.RankSelect.BitSeq
  ( BitSeq(..)
  , mempty
  , size
  , fromWord64s
  , fromPartialWord64s
  , toPartialWord64s
  , fromBools
  , toBools
  , BS.splitAt
  , take
  , drop
  , (<|), (><), (|>)
  , select1
  ) where

import Data.Coerce
import Data.Foldable
import Data.Word
import HaskellWorks.Data.Bits.BitWise
import HaskellWorks.Data.FingerTree                 (ViewL (..), ViewR (..), (<|), (><), (|>))
import HaskellWorks.Data.Positioning
import HaskellWorks.Data.RankSelect.Base.Select1    (select1)
import HaskellWorks.Data.RankSelect.Internal.BitSeq (BitSeq (BitSeq), BitSeqFt, Elem (Elem))
import Prelude                                      hiding (drop, max, min, splitAt, take)

import qualified Data.List                                    as L
import qualified HaskellWorks.Data.FingerTree                 as FT
import qualified HaskellWorks.Data.RankSelect.Internal.BitSeq as BS
import qualified HaskellWorks.Data.RankSelect.Internal.Word   as W

-- | The bit sequence containing no bits: a 'BitSeq' wrapping an empty
-- finger tree.
empty :: BitSeq
empty = BitSeq FT.empty

-- | Number of bits in the sequence, read from the bit-count component of the
-- underlying finger tree's measure.
size :: BitSeq -> Count
size (BitSeq parens) = BS.measureBitCount (FT.measure parens)

-- TODO Needs optimisation
-- | Build a 'BitSeq' from a container of words, appending each word in
-- traversal order as an element tagged with a size of 64.
--
-- The fold is strict ('foldl'') so the accumulated sequence is forced at each
-- step instead of piling up a chain of thunks on long inputs.
fromWord64s :: Traversable f => f Word64 -> BitSeq
fromWord64s = foldl' go empty
  where
    go :: BitSeq -> Word64 -> BitSeq
    go ps w = BitSeq (BS.parens ps |> Elem w 64)

-- TODO Needs optimisation
-- | Build a 'BitSeq' from a container of @(word, size)@ pairs, appending each
-- pair in traversal order as one element.
--
-- NOTE(review): the size is stored verbatim in the element; presumably it is
-- the number of significant bits in the word and callers keep it <= 64 --
-- nothing here validates that.
--
-- The fold is strict ('foldl'') so the accumulated sequence is forced at each
-- step instead of piling up a chain of thunks on long inputs.
fromPartialWord64s :: Traversable f => f (Word64, Count) -> BitSeq
fromPartialWord64s = foldl' go empty
  where
    go :: BitSeq -> (Word64, Count) -> BitSeq
    go ps (w, n) = BitSeq (BS.parens ps |> Elem w n)

-- | Flatten a 'BitSeq' into its elements, left to right, as @(word, size)@
-- pairs.
--
-- The sequence is unwrapped to its finger tree with 'coerce' and then peeled
-- from the left one element at a time until the tree is empty.
toPartialWord64s :: BitSeq -> [(Word64, Count)]
toPartialWord64s = L.unfoldr go . coerce
  where
    go :: BitSeqFt -> Maybe ((Word64, Count), BitSeqFt)
    go ft = case FT.viewl ft of
      BS.Elem w n :< rt -> Just ((w, coerce n), rt)
      FT.EmptyL         -> Nothing

-- | Build a 'BitSeq' from a list of bits, packing them into word elements.
--
-- Each incoming bit is merged into the rightmost element of the tree:
--
-- * if the tree is empty, a fresh one-bit element is created;
-- * if the rightmost element already records 64 or more bits, a fresh
--   one-bit element is appended after it;
-- * otherwise the bit is OR-ed into the rightmost word at position @n@
--   (its current size, via '.<.') and the size becomes @n + 1@.
fromBools :: [Bool] -> BitSeq
fromBools = go empty
  where
    go :: BitSeq -> [Bool] -> BitSeq
    go (BitSeq ps) (b:bs) =
      case FT.viewr ps of
        FT.EmptyR      -> go (BitSeq (FT.singleton (Elem b' 1))) bs
        lt :> Elem w n ->
          let newPs =
                if n >= 64
                  -- Rightmost word is full: start a new element.
                  then ps |> Elem b' 1
                  -- Room left: replace the rightmost element with one that
                  -- also carries the new bit.
                  else lt |> Elem (w .|. (b' .<. fromIntegral n)) (n + 1)
          in go (BitSeq newPs) bs
      where
        -- The current bit as a word: 1 for True, 0 for False.
        b' :: Word64
        b' = if b then 1 else 0
    go ps [] = ps

-- | Expand a 'BitSeq' into a list of bits by running 'toBoolsDiff' with an
-- empty tail.
toBools :: BitSeq -> [Bool]
toBools ps = toBoolsDiff ps []

-- | Expand a 'BitSeq' into bits, given a tail list.
--
-- Every @(word, size)@ element is turned into a @[Bool] -> [Bool]@ function
-- by 'W.partialToBoolsDiff', and those functions are combined with 'mconcat'.
--
-- NOTE(review): 'mconcat' here uses the @Monoid (a -> b)@ instance, so each
-- per-element function is applied to the /same/ tail and the resulting lists
-- are concatenated; the functions are not composed. With the empty tail that
-- 'toBools' passes, this is plain concatenation of the per-element bits. For
-- a non-empty tail the result depends on what 'W.partialToBoolsDiff' does
-- with its tail argument, and an empty sequence yields @[]@ rather than the
-- tail -- confirm this is intended before calling with anything but @[]@.
toBoolsDiff :: BitSeq -> [Bool] -> [Bool]
toBoolsDiff ps = mconcat (fmap go (toPartialWord64s ps))
  where
    go :: (Word64, Count) -> [Bool] -> [Bool]
    go (w, n) = W.partialToBoolsDiff (fromIntegral n) w

-- | The second half of splitting the sequence at @n@ with 'BS.splitAt',
-- i.e. what remains after the leading part is discarded.
drop :: Count -> BitSeq -> BitSeq
drop n ps = snd (BS.splitAt n ps)

-- | The first half of splitting the sequence at @n@ with 'BS.splitAt',
-- i.e. the leading part of the sequence.
take :: Count -> BitSeq -> BitSeq
take n ps = fst (BS.splitAt n ps)