module RBST.Internal (
Size(..)
, Tree(..)
, RBST(..)
, MonadRandT
, MonadRand
, empty
, emptyWithGen
, one
, oneWithGen
, defaultRandomGenerator
, clockRandomGenerator
, size
, sizeTree
, height
, lookup
, at
, insert
, delete
, remove
, take
, drop
, union
, intersection
, subtraction
, difference
, uniformR
, withTree
) where
import Control.DeepSeq (NFData (..), rnf)
import Control.Monad.Trans.State.Strict (StateT)
import qualified Control.Monad.Trans.State.Strict as State
import Data.Coerce (coerce)
import Data.Foldable (foldl')
import Data.Functor.Identity (Identity)
import Data.Word (Word64)
import GHC.Exts (IsList (..))
import GHC.Generics (Generic)
import Prelude hiding (drop, lookup, take)
import qualified System.Random.Mersenne.Pure64 as Random
-- | Number of nodes in a subtree, cached in every 'Node'
-- (see 'recomputeSize': @size l + size r + 1@, with 'Empty' counting 0).
--
-- A newtype over 'Word64'; the newtype-derived 'Num' instance lets numeric
-- literals such as @0@ and @1@ be used directly as sizes.
newtype Size = Size
  { unSize :: Word64 -- ^ The raw node count.
  } deriving stock (Show, Read, Generic)
    deriving newtype (Eq, Ord, Num, NFData)
-- | The underlying binary search tree.
--
-- @Node size key left value right@ stores the cached subtree size, the key,
-- the left subtree, the associated value and the right subtree. Keys are
-- ordered so that searches ('lookup') descend left for smaller keys and
-- right for larger ones. All fields are strict.
--
-- The derived 'Foldable' visits values in key order (left subtree, value,
-- right subtree).
data Tree k a
  = Node !Size !k !(Tree k a) !a !(Tree k a)
  | Empty
  deriving stock (Show, Read, Eq, Generic, Foldable)
  deriving anyclass (NFData)
-- | A randomized binary search tree: the tree paired with the pseudo-random
-- generator used for its random structural choices (root placement in
-- 'insert', merging in 'join', root choice in 'union', ...).
--
-- Every randomized operation threads the generator through and stores the
-- advanced generator in the result (see 'runRand'). The derived 'Foldable'
-- folds over the values of the tree.
data RBST k a = RBST
  { rbstGen  :: !Random.PureMT -- ^ Generator state for the next operation.
  , rbstTree :: !(Tree k a)     -- ^ The tree itself.
  } deriving stock (Show, Generic, Foldable)
-- | Combining two trees is their 'union'; the result uses the first
-- tree's generator.
instance Ord k => Semigroup (RBST k a) where
  left <> right = union left right

-- | The neutral element is the empty tree built on the default generator.
instance Ord k => Monoid (RBST k a) where
  mempty = emptyWithGen defaultRandomGenerator
-- | Two trees are equal when they hold the same key/value pairs.
--
-- The shape of an 'RBST' depends on the random choices made while it was
-- built (insertion order, generator state). Comparing the underlying
-- 'Tree's structurally would therefore report trees with identical
-- contents as different. Instead, the in-order key/value sequences are
-- compared. Generators are ignored, and cached sizes are compared first as
-- a cheap shortcut.
instance (Eq k, Eq a) => Eq (RBST k a) where
  (RBST _ tree1) == (RBST _ tree2) =
    sizeTree tree1 == sizeTree tree2 && pairs tree1 == pairs tree2
    where
      -- In-order (ascending key) list of entries, built in linear time.
      pairs :: Tree k a -> [(k, a)]
      pairs t = go t []

      go :: Tree k a -> [(k, a)] -> [(k, a)]
      go Empty acc = acc
      go (Node _ k l x r) acc = go l ((k, x) : go r acc)
-- | Conversion from and to association lists.
--
-- 'fromList' folds 'insert' over the pairs from left to right, starting
-- from 'empty'. 'toList' returns the pairs in ascending key order.
instance Ord k => IsList (RBST k a) where
  type Item (RBST k a) = (k, a)

  fromList :: [(k, a)] -> RBST k a
  fromList = foldl' ins empty
    where
      -- Key and value are forced before being inserted.
      ins tree (!k, !x) = insert k x tree
  {-# INLINEABLE fromList #-}

  toList :: RBST k a -> [(k, a)]
  toList (RBST _ tree) = go tree []
    where
      -- Accumulator-passing in-order traversal: O(n). Repeated
      -- left-nested (++) would re-copy each left spine, costing O(n * depth).
      go Empty acc = acc
      go (Node _ k l x r) acc = go l ((k, x) : go r acc)
  {-# INLINEABLE toList #-}
-- | Fully evaluates the stored tree. The generator is not deep-forced.
instance (NFData k, NFData a) => NFData (RBST k a) where
  rnf (RBST _ tree) = rnf tree
-- | A 'StateT' computation over base monad @m@ whose state is the tree's
-- 'Random.PureMT' generator.
type MonadRandT m a = StateT Random.PureMT m a

-- | The pure variant of 'MonadRandT' (base monad 'Identity'); the randomized
-- tree operations in this module ('insert'', 'split', 'join', 'pushDown',
-- ...) run in it and are executed by 'runRand'.
type MonadRand a = StateT Random.PureMT Identity a
-- | A generator with the fixed seed @0@, so trees built from it are
-- reproducible across runs.
defaultRandomGenerator :: Random.PureMT
defaultRandomGenerator = Random.pureMT seed
  where
    seed = 0
{-# INLINE defaultRandomGenerator #-}

-- | A fresh generator obtained from 'Random.newPureMT' (per the
-- mersenne-random-pure64 documentation, seeded from the clock).
clockRandomGenerator :: IO Random.PureMT
clockRandomGenerator = Random.newPureMT
{-# INLINE clockRandomGenerator #-}
-- | An empty tree using 'defaultRandomGenerator'.
empty :: RBST k a
empty = emptyWithGen defaultRandomGenerator
{-# INLINE empty #-}

-- | An empty tree using the supplied generator.
emptyWithGen :: Random.PureMT -> RBST k a
emptyWithGen gen = RBST { rbstGen = gen, rbstTree = Empty }
{-# INLINE emptyWithGen #-}

-- | A single-entry tree using 'defaultRandomGenerator'.
one :: k -> a -> RBST k a
one = oneWithGen defaultRandomGenerator
{-# INLINE one #-}

-- | A single-entry tree using the supplied generator.
oneWithGen :: Random.PureMT -> k -> a -> RBST k a
oneWithGen gen k x = RBST gen (oneTree k x)
{-# INLINE oneWithGen #-}

-- | A leaf node holding one entry (cached size 1).
oneTree :: k -> a -> Tree k a
oneTree k x = Node (Size 1) k Empty x Empty
{-# INLINE oneTree #-}
-- | Number of entries in the tree, read from the cached root size.
size :: RBST k a -> Int
size = sizeTreeInt . rbstTree
{-# INLINE size #-}

-- | Cached size of a subtree; 0 for 'Empty'.
sizeTree :: Tree k a -> Size
sizeTree t = case t of
  Empty          -> 0
  Node s _ _ _ _ -> s
{-# INLINE sizeTree #-}

-- | 'sizeTree' converted to 'Int'.
sizeTreeInt :: Tree k a -> Int
sizeTreeInt = fromIntegral . unSize . sizeTree
{-# INLINE sizeTreeInt #-}
-- | Height of the tree, counted in edges on the longest root-to-leaf path:
-- an empty tree has height -1 and a single node has height 0.
height :: RBST k a -> Int
height = go . rbstTree
  where
    go :: Tree k a -> Int
    go t = case t of
      Empty          -> -1
      Node _ _ l _ r -> max (go l) (go r) + 1
{-# INLINEABLE height #-}
-- | The value associated with a key, if any. Walks a single root-to-leaf
-- path and makes no random choices.
lookup :: Ord k => k -> RBST k a -> Maybe a
lookup needle = search . rbstTree
  where
    search Empty = Nothing
    search (Node _ key l val r) = case compare needle key of
      LT -> search l
      EQ -> Just val
      GT -> search r
{-# INLINEABLE lookup #-}
-- | Inserts a key/value pair, running the randomized 'insert'' with the
-- tree's generator and storing the advanced generator in the result.
insert :: Ord k => k -> a -> RBST k a -> RBST k a
insert k x (RBST gen tree) = runRand (insert' k x tree) gen
{-# INLINEABLE insert #-}
-- | Randomized insertion into a subtree of size @s@.
--
-- With probability @1/(s+1)@ the new pair becomes the root of this subtree
-- ('insertRoot'). Otherwise insertion continues in the child selected by
-- the key. If the key is already present, its value is replaced, so keys
-- stay unique:
--
-- * when the key equals the node being visited, the value is replaced in
--   place, leaving size and shape unchanged;
-- * when the key is found by 'insertRoot' (via 'split'), the old node is
--   dropped and the new root is re-randomized with 'pushDown'.
insert' :: Ord k => k -> a -> Tree k a -> MonadRand (Tree k a)
insert' k x Empty = return (oneTree k x)
insert' k x node@(Node s !k2 l _ r) = do
  guess <- uniformR (0, unSize s)
  if guess == 0
    then do
      (rep, tree) <- insertRoot k x node
      if rep then pushDown tree else pure tree
    else case compare k k2 of
      LT -> updateL node <$> insert' k x l
      -- An equal key must not descend into a child: that would add a
      -- second node with the same key instead of replacing the value.
      EQ -> pure (Node s k l x r)
      GT -> updateR node <$> insert' k x r
{-# INLINEABLE insert' #-}
-- | Removes a key and its value if present, running the randomized
-- 'delete'' with the tree's generator.
delete :: Ord k => k -> RBST k a -> RBST k a
delete k (RBST gen tree) = runRand (delete' k tree) gen
{-# INLINEABLE delete #-}
-- | Deletes a key from a subtree. The matching node is replaced by the
-- randomized 'join' of its two children. Nodes on the search path get
-- their sizes recomputed by 'updateL' / 'updateR'. An absent key leaves
-- the contents unchanged.
delete' :: Ord k => k -> Tree k a -> MonadRand (Tree k a)
delete' k tree = case tree of
  Empty -> pure Empty
  Node _ k2 l _ r -> case compare k k2 of
    EQ -> join l r
    LT -> updateL tree <$> delete' k l
    GT -> updateR tree <$> delete' k r
{-# INLINEABLE delete' #-}
-- | The entry at a zero-based position in ascending key order. Returns
-- 'Nothing' when the position is out of range; a negative position
-- descends left until it reaches 'Empty'.
at :: Int -> RBST k a -> Maybe (k, a)
at ith = go ith . rbstTree
  where
    go _ Empty = Nothing
    go i (Node _ k l x r) =
      let leftSize = sizeTreeInt l
      in case compare i leftSize of
           LT -> go i l
           EQ -> Just (k, x)
           GT -> go (i - leftSize - 1) r
{-# INLINEABLE at #-}
-- | Removes the entry at a zero-based position in ascending key order.
-- Positions that are negative or @>= size@ return the tree unchanged.
remove :: Int -> RBST k a -> RBST k a
remove n rbst@(RBST gen tree)
  | n < 0 || n >= size rbst = rbst
  | otherwise = runRand (removeAt n tree) gen
  where
    removeAt _ Empty = pure Empty
    removeAt !i node@(Node _ _ l _ r) =
      let leftSize = sizeTreeInt l
      in case compare i leftSize of
           LT -> updateL node <$> removeAt i l
           -- This node is the target: merge its children in its place.
           EQ -> join l r
           GT -> updateR node <$> removeAt (i - leftSize - 1) r
{-# INLINEABLE remove #-}
-- | The first @n@ entries in ascending key order. No random choices are
-- made; the generator is carried over unchanged.
take :: Int -> RBST k a -> RBST k a
take n rbst@(RBST gen tree)
  | n <= 0 = RBST gen Empty
  | n >= size rbst = rbst
  | otherwise = RBST gen (prefix n tree)
  where
    prefix _ Empty = Empty
    prefix 0 _ = Empty
    prefix i node@(Node _ _ l _ r) =
      case compare i leftSize of
        LT -> prefix i l
        EQ -> l
        -- Keep this node and its whole left subtree, trim the right one.
        GT -> updateR node (prefix (i - leftSize - 1) r)
      where
        leftSize = sizeTreeInt l
{-# INLINEABLE take #-}
-- | All but the first @n@ entries in ascending key order. No random
-- choices are made; the generator is carried over unchanged.
drop :: Int -> RBST k a -> RBST k a
drop n rbst@(RBST gen tree)
  | n <= 0 = rbst
  | n >= size rbst = RBST gen Empty
  | otherwise = RBST gen (suffix n tree)
  where
    suffix _ Empty = Empty
    suffix !0 t = t
    suffix !i node@(Node _ _ l _ r) =
      case compare i leftSize of
        LT -> updateL node (suffix i l)
        -- Exactly the left subtree is dropped.
        EQ -> updateL node Empty
        GT -> suffix (i - leftSize - 1) r
      where
        leftSize = sizeTreeInt l
{-# INLINEABLE drop #-}
-- | Union of two trees, using the first tree's generator.
--
-- On duplicate keys the value from the first (left) tree is kept. This is
-- consistent with 'intersection' and 'subtraction', which also take values
-- from their first argument, and makes '<>' associative on the stored
-- key/value pairs.
--
-- At each step the result's root is the root of one of the two trees,
-- chosen with probability proportional to that tree's size. The other tree
-- is 'split' around that key, and the halves are merged recursively.
union :: Ord k => RBST k a -> RBST k a -> RBST k a
union (RBST s tree1) (RBST _ tree2) = runRand (union' tree1 tree2) s
  where
    -- Invariant: the first argument always stems from the left operand,
    -- so its values win on duplicate keys. An empty side needs no work.
    union' Empty t2 = return t2
    union' t1 Empty = return t1
    union' t1@(Node s1 k1 l1 x1 r1) t2@(Node s2 k2 l2 x2 r2) = do
      let m = unSize s1
          n = unSize s2
      u <- uniformR (1, m + n)
      if u <= m
        then do
          -- Left root on top; an equal key is dropped from the right tree.
          (rep, bL, bR) <- split k1 t2
          l <- union' l1 bL
          r <- union' r1 bR
          rebuild rep k1 l x1 r
        else do
          -- Right root on top. Keep the operand order in the recursion,
          -- and if the left tree also holds k2, use its value.
          (rep, bL, bR) <- split k2 t1
          l <- union' bL l2
          r <- union' bR r2
          let x = if rep then valueOr x2 k2 t1 else x2
          rebuild rep k2 l x r

    -- New root node; re-randomized with 'pushDown' when a duplicate key
    -- was dropped while splitting.
    rebuild rep k l x r =
      (if rep then pushDown else pure) (recomputeSize (Node 0 k l x r))

    -- Value stored under @key@ in a tree, or @def@ if the key is absent.
    valueOr def _ Empty = def
    valueOr def key (Node _ k l v r) = case compare key k of
      LT -> valueOr def key l
      EQ -> v
      GT -> valueOr def key r
{-# INLINEABLE union #-}
-- | Entries whose keys occur in both trees, with values taken from the
-- first tree. Uses the first tree's generator.
intersection :: Ord k => RBST k a -> RBST k a -> RBST k a
intersection (RBST gen left) (RBST _ right) = runRand (go left right) gen
  where
    go Empty _ = pure Empty
    go (Node _ k l x r) other = do
      (found, otherL, otherR) <- split k other
      newL <- go l otherL
      newR <- go r otherR
      if found
        then pure (recomputeSize (Node 0 k newL x newR))
        -- k is not shared: drop it and merge the surviving halves.
        else join newL newR
{-# INLINEABLE intersection #-}
-- | Entries of the first tree whose keys do not occur in the second tree.
-- Uses the first tree's generator.
subtraction :: Ord k => RBST k a -> RBST k a -> RBST k a
subtraction (RBST gen left) (RBST _ right) = runRand (go left right) gen
  where
    go Empty _ = pure Empty
    go (Node _ k l x r) other = do
      (found, otherL, otherR) <- split k other
      restL <- go l otherL
      restR <- go r otherR
      if not found
        then pure (recomputeSize (Node 0 k restL x restR))
        else join restL restR
{-# INLINEABLE subtraction #-}
-- | Symmetric difference: entries whose keys occur in exactly one of the
-- two trees. Uses the first tree's generator.
difference :: Ord k => RBST k a -> RBST k a -> RBST k a
difference (RBST gen left) (RBST _ right) = runRand (go left right) gen
  where
    -- Whatever remains of the second tree had no match and is kept.
    go Empty other = pure other
    go (Node _ k l x r) other = do
      (shared, otherL, otherR) <- split k other
      restL <- go l otherL
      restR <- go r otherR
      if shared
        then join restL restR
        else pure (recomputeSize (Node 0 k restL x restR))
{-# INLINEABLE difference #-}
-- | A uniformly distributed 'Word64' in the closed interval spanned by the
-- two bounds (given in either order), drawn from the state's generator.
--
-- Uses rejection sampling to avoid modulo bias: raw draws at or above the
-- largest multiple of the range size are discarded and redrawn.
--
-- When the interval covers all of 'Word64', the range size @j - i + 1@
-- wraps to 0. Every raw draw is then already uniform and is returned as
-- is; this case is handled separately because the general formula would
-- divide by zero.
uniformR :: (Word64, Word64) -> MonadRand Word64
uniformR (x1, x2)
  | n == 0 = draw
  | otherwise = loop
  where
    (i, j)
      | x1 < x2 = (x1, x2)
      | otherwise = (x2, x1)
    -- Number of values in [i, j]; wraps to 0 for the full Word64 range.
    n = 1 + (j - i)
    buckets = maxBound `div` n
    maxN = buckets * n
    -- One raw 64-bit draw, advancing the generator.
    draw = do
      gen <- State.get
      let (!x, nextGen) = Random.randomWord64 gen
      State.put nextGen
      return x
    loop = do
      x <- draw
      if x < maxN
        then return (i + (x `div` buckets))
        else loop
{-# INLINE uniformR #-}
-- | Runs a randomized tree computation from the given generator and pairs
-- the resulting tree with the advanced generator.
runRand :: MonadRand (Tree k a) -> Random.PureMT -> RBST k a
runRand action gen0 = RBST gen1 tree
  where
    (tree, gen1) = State.runState action gen0
-- | Left child of a node; 'Empty' for an empty tree.
getL :: Tree k a -> Tree k a
getL t = case t of
  Node _ _ l _ _ -> l
  Empty          -> Empty
{-# INLINE getL #-}

-- | Right child of a node; 'Empty' for an empty tree.
getR :: Tree k a -> Tree k a
getR t = case t of
  Node _ _ _ _ r -> r
  Empty          -> Empty
{-# INLINE getR #-}
-- | Lifts a function on the underlying 'Tree' to a function on 'RBST'.
withTree :: (Tree k a -> r) -> (RBST k a -> r)
withTree f rbst = f (rbstTree rbst)
{-# INLINE withTree #-}
-- | Re-derives a node's cached size from its children, ignoring the size
-- currently stored in the node: @size l + size r + 1@.
recomputeSize :: Tree k a -> Tree k a
recomputeSize tree = case tree of
  Empty -> Empty
  Node _ k l c r ->
    let !s = sizeTree l + sizeTree r + 1
    in Node s k l c r
{-# INLINE recomputeSize #-}
-- | Replaces the left child of a node and recomputes its size. Applied to
-- 'Empty', it returns the new child itself.
updateL :: Tree k a -> Tree k a -> Tree k a
updateL tree newL = case tree of
  Empty          -> newL
  Node s k _ c r -> recomputeSize (Node s k newL c r)
{-# INLINE updateL #-}

-- | Replaces the right child of a node and recomputes its size. Applied to
-- 'Empty', it returns the new child itself.
updateR :: Tree k a -> Tree k a -> Tree k a
updateR tree newR = case tree of
  Empty          -> newR
  Node s k l c _ -> recomputeSize (Node s k l c newR)
{-# INLINE updateR #-}
-- | Makes @k@ the root of the tree by splitting the tree around it. The
-- flag reports whether @k@ was already present; in that case 'split' has
-- dropped the old node, and the new value is stored at the root.
insertRoot :: Ord k => k -> a -> Tree k a -> MonadRand (Bool, Tree k a)
insertRoot k x Empty = pure (False, oneTree k x)
insertRoot k x tree =
  split k tree >>= \(rep, l, r) ->
    pure (rep, recomputeSize (Node 0 k l x r))
{-# INLINE insertRoot #-}
-- | Splits a tree around @k@ into the entries with smaller keys and those
-- with larger keys. The flag is 'True' when a node with key @k@ was found
-- on the search path; such a node appears in neither half.
split :: Ord k => k -> Tree k a -> MonadRand (Bool, Tree k a, Tree k a)
split _ Empty = pure (False, Empty, Empty)
split k node@(Node _ k2 l _ r) = case compare k k2 of
  LT -> do
    (found, lo, hi) <- split k l
    pure (found, lo, updateL node hi)
  EQ -> do
    -- The right subtree is split again, and anything it yields below k is
    -- merged with l. NOTE(review): with unique keys this yields (Empty, r),
    -- presumably a guard against duplicate keys.
    (_, lo, hi) <- split k r
    merged <- join l lo
    pure (True, merged, hi)
  GT -> do
    (found, lo, hi) <- split k r
    pure (found, updateR node lo, hi)
{-# INLINE split #-}
-- | Re-randomizes the position of a tree's root by rotating it downwards.
--
-- For a root whose subtrees have sizes @m@ and @n@, the root stays put with
-- probability @1/(m+n+1)@. Otherwise it is rotated below its left child
-- (probability @m/(m+n+1)@) or its right child (probability @n/(m+n+1)@),
-- and the process repeats one level down. Called after a duplicate key was
-- dropped while building a new root ('insert'', 'union').
--
-- Precondition: the input is a 'Node'.
pushDown :: Tree k a -> MonadRand (Tree k a)
pushDown Empty = error "pushDown: the input tree must not be empty."
pushDown tree@(Node _ _ l _ r) = do
  let !m = fromIntegral (sizeTreeInt l)
      !n = fromIntegral (sizeTreeInt r)
      !total = m + n
  u <- uniformR (0, total)
  if u < m
    -- Right rotation: the left child becomes the root and the old root
    -- continues down inside its right subtree.
    then updateR l <$> pushDown (updateL tree (getR l))
    else if u < total
      -- Left rotation, symmetric to the case above.
      then updateL r <$> pushDown (updateR tree (getL r))
      else return tree
-- | Joins two trees, expecting every key of the first to be smaller than
-- every key of the second (as at the call sites in this module, e.g. the
-- two children of a deleted node).
--
-- With sizes @m@ and @n@, the first tree's root becomes the root with
-- probability @m/(m+n)@ and the second's with @n/(m+n)@. The remaining
-- pieces are joined recursively.
join :: Tree k a -> Tree k a -> MonadRand (Tree k a)
join Empty q = return q
join p Empty = return p
join p@(Node s _ _ _ pR) q@(Node s2 _ qL _ _) = do
  -- Draw from the m+n values in [1, m+n]; the first m select p's root.
  -- (Drawing from [0, m+n] gave p only m/(m+n+1) and biased towards q.)
  guess <- uniformR (1, unSize (s + s2))
  if guess <= unSize s
    then updateR p <$> join pR q
    else updateL q <$> join p qL
{-# INLINE join #-}