-- | A set of values kept in a sorted 'V.Vector'; lookups go through 'binarySearch'.
module Z.Data.Vector.FlatSet
  ( -- * FlatSet backed by sorted vector
    FlatSet, sortedValues, size, null, empty, map'
  , pack, packN, packR, packRN
  , unpack, unpackR, packVector, packVectorR
  , elem
  , delete
  , insert
  , merge
    -- * search on vectors
  , binarySearch
  ) where
import           Control.DeepSeq
import           Control.Monad
import           Control.Monad.ST
import qualified Data.Primitive.SmallArray  as A
import qualified Data.Semigroup             as Semigroup
import qualified Data.Monoid                as Monoid
import qualified Z.Data.Vector.Base         as V
import qualified Z.Data.Vector.Sort         as V
import qualified Z.Data.Text.ShowT          as T
import           Data.Bits                   (shiftR)
import           Data.Data
import           Prelude hiding (elem, null)
import           Test.QuickCheck.Arbitrary (Arbitrary(..), CoArbitrary(..))
-- | A set represented by a vector of its elements.
--
-- The builders in this module ('pack', 'packVector', ...) sort the elements with
-- 'V.mergeSort' and collapse adjacent equal ones, while 'elem', 'insert' and 'delete'
-- binary-search the vector, so the wrapped vector is expected to be sorted ascending.
-- 'sortedValues' gives access to that underlying vector.
newtype FlatSet v = FlatSet { sortedValues :: V.Vector v }
    deriving (Show, Eq, Ord, Typeable, Foldable)
-- | Renders @FlatSet {@, then the elements (each at precedence 0) joined with 'T.comma',
-- then @}@. The whole output is wrapped by 'T.parenWhen' when the precedence exceeds 10.
instance T.ShowT v => T.ShowT (FlatSet v) where
    {-# INLINE toTextBuilder #-}
    toTextBuilder prec (FlatSet elems) =
        T.parenWhen (prec > 10)
            ( T.unsafeFromBuilder "FlatSet {"
                >> T.intercalateVec T.comma (T.toTextBuilder 0) elems
                >> T.char7 '}' )
-- | '<>' combines two sets with 'merge'.
instance Ord v => Semigroup.Semigroup (FlatSet v) where
    {-# INLINE (<>) #-}
    lhs <> rhs = merge lhs rhs
-- | The identity is the empty set and 'mappend' is 'merge'.
instance Ord v => Monoid.Monoid (FlatSet v) where
    {-# INLINE mempty #-}
    mempty = FlatSet V.empty
    {-# INLINE mappend #-}
    mappend lhs rhs = merge lhs rhs
-- | Forcing a set means forcing its underlying vector.
instance NFData v => NFData (FlatSet v) where
    {-# INLINE rnf #-}
    rnf = rnf . sortedValues
-- | Generates a set by 'pack'ing an arbitrary list; shrinks by shrinking the element
-- list and 'pack'ing every candidate again.
instance (Ord v, Arbitrary v) => Arbitrary (FlatSet v) where
    arbitrary = fmap pack arbitrary
    shrink set = [ pack smaller | smaller <- shrink (unpack set) ]
-- | Perturbs a generator using the set's element list.
instance (CoArbitrary v) => CoArbitrary (FlatSet v) where
    coarbitrary set = coarbitrary (unpack set)
-- | Number of elements in the set, i.e. the length of the underlying vector.
size :: FlatSet v -> Int
{-# INLINE size #-}
size (FlatSet vs) = V.length vs
-- | Whether the underlying vector is empty (as reported by 'V.null').
null :: FlatSet v -> Bool
{-# INLINE null #-}
null (FlatSet vs) = V.null vs
-- | Apply a function to every element via 'V.map'' and rebuild the set with
-- 'packVector', which re-sorts the results and collapses equal ones.
map' :: forall v. Ord v => (v -> v) -> FlatSet v -> FlatSet v
{-# INLINE map' #-}
map' f = packVector . mapped . sortedValues
  where
    -- The signature pins the result vector type of the polymorphic 'V.map''.
    mapped :: V.Vector v -> V.Vector v
    mapped = V.map' f
-- | The set wrapping 'V.empty'.
empty :: FlatSet v
{-# INLINE empty #-}
empty = FlatSet { sortedValues = V.empty }
-- | Build a set from a list: pack it into a vector, sort it with 'V.mergeSort' and
-- collapse adjacent equal elements with 'V.mergeDupAdjacentLeft'
-- (presumably keeping the left one of each run; confirm in "Z.Data.Vector.Sort").
pack :: Ord v => [v] -> FlatSet v
{-# INLINE pack #-}
pack = FlatSet . V.mergeDupAdjacentLeft (==) . V.mergeSort . V.pack
-- | Like 'pack', but the list is packed with 'V.packN', to which the 'Int' argument is
-- forwarded (presumably an initial size hint; confirm in "Z.Data.Vector.Base").
packN :: Ord v => Int -> [v] -> FlatSet v
{-# INLINE packN #-}
packN n = FlatSet . V.mergeDupAdjacentLeft (==) . V.mergeSort . V.packN n
-- | Build a set from a list: pack it into a vector, sort it with 'V.mergeSort' and
-- collapse adjacent equal elements with 'V.mergeDupAdjacentRight'
-- (presumably keeping the right one of each run; confirm in "Z.Data.Vector.Sort").
packR :: Ord v => [v] -> FlatSet v
{-# INLINE packR #-}
packR = FlatSet . V.mergeDupAdjacentRight (==) . V.mergeSort . V.pack
-- | Like 'packR', but the list is packed with 'V.packN', to which the 'Int' argument is
-- forwarded (presumably an initial size hint; confirm in "Z.Data.Vector.Base").
packRN :: Ord v => Int -> [v] -> FlatSet v
{-# INLINE packRN #-}
packRN n = FlatSet . V.mergeDupAdjacentRight (==) . V.mergeSort . V.packN n
-- | The elements of the underlying vector as a list, produced by 'V.unpack'.
unpack :: FlatSet v -> [v]
{-# INLINE unpack #-}
unpack (FlatSet vs) = V.unpack vs
-- | The elements of the underlying vector as a list, produced by 'V.unpackR'
-- (presumably in reverse order; confirm in "Z.Data.Vector.Base").
unpackR :: FlatSet v -> [v]
{-# INLINE unpackR #-}
unpackR (FlatSet vs) = V.unpackR vs
-- | Build a set from a vector: sort it with 'V.mergeSort' and collapse adjacent equal
-- elements with 'V.mergeDupAdjacentLeft'
-- (presumably keeping the left one of each run; confirm in "Z.Data.Vector.Sort").
packVector :: Ord v => V.Vector v -> FlatSet v
{-# INLINE packVector #-}
packVector = FlatSet . V.mergeDupAdjacentLeft (==) . V.mergeSort
-- | Build a set from a vector: sort it with 'V.mergeSort' and collapse adjacent equal
-- elements with 'V.mergeDupAdjacentRight'
-- (presumably keeping the right one of each run; confirm in "Z.Data.Vector.Sort").
packVectorR :: Ord v => V.Vector v -> FlatSet v
{-# INLINE packVectorR #-}
packVectorR = FlatSet . V.mergeDupAdjacentRight (==) . V.mergeSort
-- | Membership test: 'True' exactly when 'binarySearch' finds the value in the
-- underlying vector.
elem :: Ord v => v -> FlatSet v -> Bool
{-# INLINE elem #-}
elem v (FlatSet vec) = either (const False) (const True) (binarySearch vec v)
-- | Insert a value into the set.
--
-- If 'binarySearch' finds an equal element the set is returned unchanged. Otherwise a
-- new vector of length @l+1@ is created: the elements before the insertion point are
-- copied, the new value is written, and the remaining elements are copied after it.
insert :: Ord v => v -> FlatSet v -> FlatSet v
{-# INLINE insert #-}
insert v m@(FlatSet vec@(V.Vector arr s l)) =
    case binarySearch vec v of
        Left i -> FlatSet (V.create (l+1) (\ marr -> do
            -- 'binarySearch' reports offsets into @arr@, whereas the fresh array starts
            -- at 0, so translate to a slice-relative position. An empty slice always
            -- yields @Left 0@ regardless of @s@, hence the special case.
            let !pos  = if l == 0 then 0 else i - s
                !rest = l - pos
            when (pos > 0) $ A.copySmallArray marr 0 arr s pos
            A.writeSmallArray marr pos v
            when (rest > 0) $ A.copySmallArray marr (pos+1) arr (s+pos) rest))
        Right _ -> m
-- | Remove a value from the set.
--
-- If 'binarySearch' does not find the value the set is returned unchanged. Otherwise a
-- new vector of length @l-1@ is created from the elements before the hit followed by
-- the elements after it.
delete :: Ord v => v -> FlatSet v -> FlatSet v
{-# INLINE delete #-}
delete v m@(FlatSet vec@(V.Vector arr s l)) =
    case binarySearch vec v of
        Left _ -> m
        Right i -> FlatSet $ V.create (l-1) (\ marr -> do
            -- @i@ is an offset into @arr@; @pos@ is the hit's position inside the slice
            -- and therefore also the number of elements that precede it.
            let !pos = i - s
                !end = s+l
                !j   = i+1
            when (pos > 0) $ A.copySmallArray marr 0 arr s pos
            -- The tail must land directly after the copied prefix, not at offset 0.
            when (end > j) $ A.copySmallArray marr pos arr j (end-j))
-- | Merge two sets into one.
--
-- If either set is empty the other one is returned as is. Otherwise both underlying
-- vectors are walked in parallel and the smaller head is written to the output; when
-- both heads compare 'EQ' only the element from the right set is written. Once one side
-- is exhausted, the rest of the other side is copied in one go.
merge :: forall v . Ord v => FlatSet v -> FlatSet v -> FlatSet v
{-# INLINE merge #-}
merge fmL@(FlatSet (V.Vector arrL sL lL)) fmR@(FlatSet (V.Vector arrR sR lR))
    | null fmL = fmR
    | null fmR = fmL
    -- @lL+lR@ is the largest possible result; 'go' returns how many slots were filled.
    | otherwise = FlatSet (V.createN (lL+lR) (go sL sR 0))
  where
    -- Exclusive end offsets of the two slices inside their arrays.
    endL = sL + lL
    endR = sR + lR
    -- @i@ and @j@ are read offsets into @arrL@ / @arrR@, @k@ is the write offset.
    go :: Int -> Int -> Int -> A.SmallMutableArray s v -> ST s Int
    go !i !j !k marr
        | i >= endL = do
            -- @j@ is an array offset, so what is left is measured against @endR@.
            let !rest = endR - j
            A.copySmallArray marr k arrR j rest
            return $! k + rest
        | j >= endR = do
            let !rest = endL - i
            A.copySmallArray marr k arrL i rest
            return $! k + rest
        | otherwise = do
            vL <- arrL `A.indexSmallArrayM` i
            vR <- arrR `A.indexSmallArrayM` j
            case vL `compare` vR of LT -> do A.writeSmallArray marr k vL
                                             go (i+1) j (k+1) marr
                                    EQ -> do A.writeSmallArray marr k vR
                                             go (i+1) (j+1) (k+1) marr
                                    _  -> do A.writeSmallArray marr k vR
                                             go i (j+1) (k+1) marr
-- | Binary-search an ascending-sorted vector for a value.
--
-- Returns @Right i@ when an element comparing 'EQ' to the value sits at @i@, otherwise
-- @Left i@ where @i@ is the index at which the value would have to be inserted to keep
-- the order.
--
-- For a non-empty vector the reported index is an offset into the underlying array
-- (the search runs over @[s0, s0+l-1]@ and indexes @arr@ directly), not a position
-- relative to the start of the slice; the two coincide only when @s0@ is 0.
-- An empty vector always yields @Left 0@, and the value is not forced in that case.
binarySearch :: Ord v => V.Vector v -> v -> Either Int Int
{-# INLINABLE binarySearch #-}
binarySearch (V.Vector _ _ 0) _   = Left 0
binarySearch (V.Vector arr s0 l) !v' = go s0 (s0+l-1)
  where
    -- @s@ and @e@ are the inclusive bounds of the range still to be searched.
    go !s !e
        -- One candidate left: decide between a hit, inserting before it or after it.
        | s == e =
            let v = arr `A.indexSmallArray` s
            in case v' `compare` v of LT -> Left s
                                      GT -> let !s' = s+1 in Left s'
                                      _  -> Right s
        -- The range became empty (reached via @mid-1@ below): insert at @s@.
        | s >  e = Left s
        | otherwise =
            let !mid = (s+e) `shiftR` 1
                v = arr `A.indexSmallArray` mid
            in case v' `compare` v of LT -> go s (mid-1)
                                      GT -> go (mid+1) e
                                      _  -> Right mid