{-|
Module      : Z.Data.Vector.FlatIntSet
Description : Fast int set based on sorted vector
Copyright   : (c) Dong Han, 2017-2019
              (c) Tao He, 2018-2019
License     : BSD
Maintainer  : winterland1989@gmail.com
Stability   : experimental
Portability : non-portable

This module provides a simple int set based on sorted vector and binary search. It's particularly
suitable for small sized value collections such as deserializing intermediate representation.
It can also be used in various places where insertions and deletions are rare but fast 'elem' lookups are required.

-}

module Z.Data.Vector.FlatIntSet
  ( -- * FlatIntSet backed by sorted vector
    FlatIntSet, sortedValues, size, null, empty, map'
  , pack, packN, packR, packRN
  , unpack, unpackR, packVector, packVectorR
  , elem
  , delete
  , insert
  , merge
    -- * search on vectors
  , binarySearch
  ) where

import           Control.DeepSeq
import           Control.Monad
import           Control.Monad.ST
import qualified Data.Semigroup             as Semigroup
import qualified Data.Monoid                as Monoid
import qualified Data.Primitive.PrimArray   as A
import qualified Z.Data.Vector.Base         as V
import qualified Z.Data.Vector.Extra        as V
import qualified Z.Data.Vector.Sort         as V
import qualified Z.Data.Text.Print          as T
import           Data.Bits                   (unsafeShiftR)
import           Data.Data
import           Prelude hiding (elem, null)
import           Test.QuickCheck.Arbitrary (Arbitrary(..), CoArbitrary(..))

--------------------------------------------------------------------------------

-- | A set of 'Int's represented by a single 'V.PrimVector'.
--
-- The construction functions in this module ('pack', 'packVector', ...) sort
-- the values and merge adjacent duplicates, and 'elem' \/ 'insert' \/ 'delete'
-- rely on that ordering through 'binarySearch'.
newtype FlatIntSet = FlatIntSet
    { sortedValues :: V.PrimVector Int
      -- ^ The underlying vector of values.
    }
    deriving (Show, Eq, Ord, Typeable)

-- | Textual rendering: @FlatIntSet{@, the values separated by 'T.comma', then @}@.
-- The whole output is parenthesised (via 'T.parenWhen') when the surrounding
-- precedence is greater than 10.
instance T.Print FlatIntSet where
    {-# INLINE toUTF8BuilderP #-}
    toUTF8BuilderP prec (FlatIntSet vs) = T.parenWhen (prec > 10) rendered
      where
        -- Opening tag, comma separated elements, closing brace, in sequence.
        rendered =
            "FlatIntSet{"
                >> T.intercalateVec T.comma T.toUTF8Builder vs
                >> T.char7 '}'

-- | Combining two sets is 'merge'.
instance Semigroup.Semigroup FlatIntSet where
    {-# INLINE (<>) #-}
    (<>) = merge

-- | The identity element is 'empty'; 'mappend' defers to the 'Semigroup.Semigroup' instance.
instance Monoid.Monoid FlatIntSet where
    {-# INLINE mempty #-}
    mempty = empty
    {-# INLINE mappend #-}
    mappend = (<>)

-- | Forcing a set means forcing its underlying vector.
instance NFData FlatIntSet where
    {-# INLINE rnf #-}
    rnf = rnf . sortedValues

-- | Generates a set by packing an arbitrary list of 'Int's; shrinking shrinks
-- the unpacked list and re-packs every candidate.
instance Arbitrary FlatIntSet where
    arbitrary = fmap pack arbitrary
    shrink s = map pack (shrink (unpack s))

-- | Perturbs a generator using the list of values held by the set.
instance CoArbitrary FlatIntSet where
    coarbitrary s = coarbitrary (unpack s)

-- | Number of values in the set, i.e. the length of the underlying vector.
size :: FlatIntSet -> Int
{-# INLINE size #-}
size (FlatIntSet vs) = V.length vs

-- | Test whether the set's underlying vector is empty.
null :: FlatIntSet -> Bool
{-# INLINE null #-}
null (FlatIntSet vs) = V.null vs

-- | Apply a function to every value of a set.
--
-- The mapped values are re-packed with 'packVector', so the result may be
-- smaller than the input when the function sends several values to the same one.
map' :: (Int -> Int) -> FlatIntSet -> FlatIntSet
{-# INLINE map' #-}
map' f = packVector . V.map' f . sortedValues

-- | /O(1)/ The set with no values, backed by 'V.empty'.
empty :: FlatIntSet
{-# NOINLINE empty #-}
empty = FlatIntSet V.empty

-- | /O(N*logN)/ Build a set from a list of values.
--
-- The list is packed into a vector, sorted with 'V.mergeSort', and adjacent
-- equal values are merged keeping the left one.
pack :: [Int] -> FlatIntSet
{-# INLINABLE pack #-}
pack vs = FlatIntSet . V.mergeDupAdjacentLeft (==) . V.mergeSort $ V.pack vs

-- | /O(N*logN)/ Build a set from a list of values, passing a suggested size
-- to 'V.packN'.
--
-- After packing, the vector is sorted with 'V.mergeSort' and adjacent equal
-- values are merged keeping the left one.
packN :: Int -> [Int] -> FlatIntSet
{-# INLINABLE packN #-}
packN n vs = FlatIntSet . V.mergeDupAdjacentLeft (==) . V.mergeSort $ V.packN n vs

-- | /O(N*logN)/ Build a set from a list of values.
--
-- The list is packed into a vector, sorted with 'V.mergeSort', and adjacent
-- equal values are merged keeping the right one.
packR :: [Int] -> FlatIntSet
{-# INLINABLE packR #-}
packR vs = FlatIntSet . V.mergeDupAdjacentRight (==) . V.mergeSort $ V.pack vs

-- | /O(N*logN)/ Build a set from a list of values, passing a suggested size
-- to 'V.packN'.
--
-- After packing, the vector is sorted with 'V.mergeSort' and adjacent equal
-- values are merged keeping the right one.
packRN :: Int -> [Int] -> FlatIntSet
{-# INLINABLE packRN #-}
packRN n vs = FlatIntSet . V.mergeDupAdjacentRight (==) . V.mergeSort $ V.packN n vs

-- | /O(N)/ Convert a set to the list of its values, in ascending order.
--
-- This function works with @foldr/build@ fusion in base.
unpack :: FlatIntSet -> [Int]
{-# INLINE unpack #-}
unpack = V.unpack . sortedValues

-- | /O(N)/ Convert a set to the list of its values, in descending order.
--
-- This function works with @foldr/build@ fusion in base.
unpackR :: FlatIntSet -> [Int]
{-# INLINE unpackR #-}
unpackR = V.unpackR . sortedValues

-- | /O(N*logN)/ Build a set from a vector of values.
--
-- The vector is sorted with 'V.mergeSort' and adjacent equal values are
-- merged keeping the left one.
packVector :: V.PrimVector Int -> FlatIntSet
{-# INLINABLE packVector #-}
packVector vs = FlatIntSet . V.mergeDupAdjacentLeft (==) $ V.mergeSort vs

-- | /O(N*logN)/ Build a set from a vector of values.
--
-- The vector is sorted with 'V.mergeSort' and adjacent equal values are
-- merged keeping the right one.
packVectorR :: V.PrimVector Int -> FlatIntSet
{-# INLINABLE packVectorR #-}
packVectorR vs = FlatIntSet . V.mergeDupAdjacentRight (==) $ V.mergeSort vs

-- | /O(logN)/ Membership test, implemented with 'binarySearch' on the
-- underlying vector: a 'Left' result means absent, a 'Right' result means present.
elem :: Int -> FlatIntSet -> Bool
{-# INLINABLE elem #-}
elem x (FlatIntSet vs) = either (const False) (const True) (binarySearch vs x)

-- | /O(N)/ Add a value to the set.
--
-- The position is looked up with 'binarySearch': when the value is missing
-- it is inserted at the reported index, when it is already present the set
-- is returned unchanged.
insert :: Int -> FlatIntSet -> FlatIntSet
{-# INLINABLE insert #-}
insert x set@(FlatIntSet vs) =
    either insertAt (const set) (binarySearch vs x)
  where
    -- Build a new vector with @x@ placed at the insertion index.
    insertAt ix = FlatIntSet (V.unsafeInsertIndex vs ix x)

-- | /O(N)/ Remove a value from the set.
--
-- The position is looked up with 'binarySearch': when the value is present
-- the element at the reported index is dropped, otherwise the set is
-- returned unchanged.
delete :: Int -> FlatIntSet -> FlatIntSet
{-# INLINABLE delete #-}
delete x set@(FlatIntSet vs) =
    either (const set) dropAt (binarySearch vs x)
  where
    -- Build a new vector without the element at the found index.
    dropAt ix = FlatIntSet (V.unsafeDeleteIndex vs ix)

-- | /O(n+m)/ Merge two 'FlatIntSet', prefer right value on value duplication.
--
-- If either argument is empty the other one is returned as is. Otherwise both
-- underlying vectors are walked in lock step and the smaller head is written
-- to a fresh array of capacity @lL+lR@; equal heads are written once.
merge :: FlatIntSet -> FlatIntSet -> FlatIntSet
{-# INLINABLE merge #-}
merge fmL@(FlatIntSet (V.PrimVector arrL sL lL)) fmR@(FlatIntSet (V.PrimVector arrR sR lR))
    | null fmL = fmR
    | null fmR = fmL
    | otherwise = FlatIntSet (V.createN (lL+lR) (go sL sR 0))
  where
    -- Exclusive end offsets of each slice inside its backing array.
    endL = sL + lL
    endR = sR + lR
    -- @i@ and @j@ are read offsets into @arrL@ and @arrR@. They are offsets
    -- into the backing arrays (starting at @sL@ / @sR@), NOT slice-relative
    -- indices. @k@ is the write offset into the output array; the action
    -- returns the final @k@, i.e. the number of elements written
    -- (presumably what 'V.createN' uses as the result length).
    go :: Int -> Int -> Int -> A.MutablePrimArray s Int -> ST s Int
    go !i !j !k marr
        | i >= endL = do
            -- Left side exhausted: bulk-copy what is left of the right side.
            -- The remaining count must be measured against @endR@ because
            -- @j@ already includes the slice offset @sR@.
            let !rest = endR - j
            A.copyPrimArray marr k arrR j rest
            return $! k + rest
        | j >= endR = do
            -- Right side exhausted: bulk-copy what is left of the left side,
            -- again counting against the absolute end offset.
            let !rest = endL - i
            A.copyPrimArray marr k arrL i rest
            return $! k + rest
        | otherwise = do
            let !vL = arrL `A.indexPrimArray` i
            let !vR = arrR `A.indexPrimArray` j
            case vL `compare` vR of
                LT -> do
                    A.writePrimArray marr k vL
                    go (i+1) j (k+1) marr
                EQ -> do
                    -- Duplicate value: emit it once (taking the right one)
                    -- and advance both sides.
                    A.writePrimArray marr k vR
                    go (i+1) (j+1) (k+1) marr
                GT -> do
                    A.writePrimArray marr k vR
                    go i (j+1) (k+1) marr

--------------------------------------------------------------------------------

-- | Find the value's index in the vector slice, if value exists return 'Right',
-- otherwise 'Left', i.e. the insert index.
--
-- The returned index is relative to the start of the slice (@0@ is the first
-- element of the vector), independent of the slice's offset into its backing
-- array. An empty vector always yields @Left 0@.
--
-- This function only works on ascending sorted vectors.
binarySearch :: V.PrimVector Int -> Int -> Either Int Int
{-# INLINABLE binarySearch #-}
binarySearch (V.PrimVector _ _ 0) _ = Left 0
binarySearch (V.PrimVector arr s0 l) !v' = go 0 (l-1)
  where
    -- Read the element at a slice-relative index; the slice offset @s0@ is
    -- applied only here so that every index handled by @go@ stays relative.
    at ix = arr `A.indexPrimArray` (s0 + ix)

    -- Search the inclusive, slice-relative range @[s, e]@.
    go :: Int -> Int -> Either Int Int
    go !s !e
        | s == e =
            -- Single candidate left: decide between hit, insert-before
            -- and insert-after.
            case v' `compare` at s of
                LT -> Left s
                GT -> let !s' = s+1 in Left s'
                EQ -> Right s
        | s > e = Left s
        | otherwise =
            let !mid = (s+e) `unsafeShiftR` 1
            in case v' `compare` at mid of
                LT -> go s (mid-1)
                GT -> go (mid+1) e
                EQ -> Right mid