{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Graph.Dynamic.Internal.Avl
( Tree
, singleton
, append
, concat
, join
, split
, root
, connected
, label
, aggregate
, toList
, freeze
, print
, assertInvariants
, assertSingleton
, assertRoot
) where
import Control.Monad (foldM, when)
import Control.Monad.Primitive (PrimMonad (..))
import qualified Data.Graph.Dynamic.Internal.Tree as Class
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NonEmpty
import Data.Monoid ((<>))
import Data.Primitive.MutVar (MutVar)
import qualified Data.Primitive.MutVar as MutVar
import qualified Data.Tree as Tree
import Prelude hiding (concat, print)
data Tree s a v = Tree
{ tParent :: {-# UNPACK #-} !(MutVar s (Tree s a v))
, tLeft :: {-# UNPACK #-} !(MutVar s (Tree s a v))
, tRight :: {-# UNPACK #-} !(MutVar s (Tree s a v))
, tAggs :: {-# UNPACK #-} !(MutVar s (Aggs v))
, tLabel :: !a
, tValue :: !v
}
instance Eq (Tree s a v) where
t1 == t2 = tParent t1 == tParent t2
data Aggs v = Aggs
{ aHeight :: {-# UNPACK #-} !Int
, aAggregate :: !v
} deriving (Eq, Show)
emptyAggs :: Monoid v => Aggs v
emptyAggs = Aggs 0 mempty
singletonAggs :: v -> Aggs v
singletonAggs = Aggs 1
joinAggs :: Monoid v => Aggs v -> v -> Aggs v -> Aggs v
joinAggs (Aggs lh la) a (Aggs rh ra) =
Aggs (max lh rh + 1) (la <> a <> ra)
singleton :: PrimMonad m => a -> v -> m (Tree (PrimState m) a v)
singleton tLabel tValue = do
tParent <- MutVar.newMutVar undefined
tLeft <- MutVar.newMutVar undefined
tRight <- MutVar.newMutVar undefined
tAggs <- MutVar.newMutVar $ singletonAggs tValue
let tree = Tree {..}
MutVar.writeMutVar tParent tree
MutVar.writeMutVar tLeft tree
MutVar.writeMutVar tRight tree
return tree
root :: PrimMonad m => Tree (PrimState m) a v -> m (Tree (PrimState m) a v)
root tree@Tree {..} = do
parent <- MutVar.readMutVar tParent
if parent == tree then return tree else root parent
concat
:: (PrimMonad m, Monoid v)
=> NonEmpty (Tree (PrimState m) a v)
-> m (Tree (PrimState m) a v)
concat (x0 NonEmpty.:| xs0) =
foldM append x0 xs0
split
:: (PrimMonad m, Monoid v)
=> Tree (PrimState m) a v
-> m (Maybe (Tree (PrimState m) a v), Maybe (Tree (PrimState m) a v))
split x0 = do
(mbL, mbR, p, left) <- cut x0
if p == x0 then
return (mbL, mbR)
else do
upwards mbL mbR p left
where
upwards lacc0 racc0 x left0 = do
(mbL, mbR, p, left1) <- cut x
if left0 then do
racc1 <- join racc0 x mbR
if p == x then
return (lacc0, Just racc1)
else
upwards lacc0 (Just racc1) p left1
else do
lacc1 <- join mbL x lacc0
if p == x then
return (Just lacc1, racc0)
else
upwards (Just lacc1) racc0 p left1
cut x = do
p <- MutVar.readMutVar (tParent x)
pl <- MutVar.readMutVar (tLeft p)
l <- MutVar.readMutVar (tLeft x)
r <- MutVar.readMutVar (tRight x)
when (l /= x) $ removeParent l
when (r /= x) $ removeParent r
removeParent x
removeLeft x
removeRight x
updateAggs x
if pl == x then removeLeft p else removeRight p
return
( if l == x then Nothing else Just l
, if r == x then Nothing else Just r
, p
, pl == x
)
append
:: (PrimMonad m, Monoid v)
=> Tree (PrimState m) a v
-> Tree (PrimState m) a v
-> m (Tree (PrimState m) a v)
append l0 r0 = do
rm <- getRightMost l0
(mbL, mbR) <- split rm
case mbR of
Just _ -> error "append: invalid state"
_ -> assertSingleton rm
join mbL rm (Just r0)
where
getRightMost x = do
r <- MutVar.readMutVar (tRight x)
if r == x then return x else getRightMost r
connected
:: (PrimMonad m, Monoid v)
=> Tree (PrimState m) a v
-> Tree (PrimState m) a v
-> m Bool
connected x y = do
xr <- root x
yr <- root y
return $ xr == yr
label :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> m a
label = return . tLabel
aggregate :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> m v
aggregate = fmap aAggregate . MutVar.readMutVar . tAggs
toList
:: PrimMonad m => Tree (PrimState m) a v -> m [a]
toList = go []
where
go acc0 tree@Tree {..} = do
left <- MutVar.readMutVar tLeft
right <- MutVar.readMutVar tRight
acc1 <- if right == tree then return acc0 else go acc0 right
let acc2 = tLabel : acc1
if left == tree then return acc2 else go acc2 left
join
:: (PrimMonad m, Monoid v)
=> Maybe (Tree (PrimState m) a v)
-> Tree (PrimState m) a v
-> Maybe (Tree (PrimState m) a v)
-> m (Tree (PrimState m) a v)
join mbL c mbR = do
lh <- maybe (return 0) (fmap aHeight . MutVar.readMutVar . tAggs) mbL
rh <- maybe (return 0) (fmap aHeight . MutVar.readMutVar . tAggs) mbR
if | lh > rh + 1, Just l <- mbL ->
joinRight l c mbR
| rh > lh + 1, Just r <- mbR ->
joinLeft mbL c r
| otherwise -> do
case mbL of Just l -> setLeft c l; _ -> return ()
case mbR of Just r -> setRight c r; _ -> return ()
updateAggs c
return c
joinLeft
:: (PrimMonad m, Monoid v)
=> Maybe (Tree (PrimState m) a v)
-> Tree (PrimState m) a v
-> Tree (PrimState m) a v
-> m (Tree (PrimState m) a v)
joinLeft mbL c r = do
rl <- MutVar.readMutVar (tLeft r)
rla <- leftAggs r rl
rr <- MutVar.readMutVar (tRight r)
rra <- rightAggs r rr
la <- maybe (return emptyAggs) (MutVar.readMutVar . tAggs) mbL
if aHeight rla <= aHeight la + 1 then do
setLeft r c
when (rl /= r) $ setRight c rl
case mbL of Just l -> setLeft c l; _ -> return ()
let !ca = joinAggs rla (tValue c) la
if aHeight rra + 1 < aHeight ca then do
rotateLeft c rl
rotateRight r rl
updateAggs c
updateAggs r
updateAggsToRoot rl
else do
updateAggs c
updateAggs r
upLeft r
else
joinLeft mbL c rl
upLeft
:: (PrimMonad m, Monoid v)
=> Tree (PrimState m) a v
-> m (Tree (PrimState m) a v)
upLeft l = do
p <- MutVar.readMutVar (tParent l)
if p == l then
return l
else do
r <- MutVar.readMutVar (tRight p)
ra <- rightAggs p r
la <- leftAggs p l
if aHeight ra + 1 < aHeight la then do
rotateRight p l
updateAggs p
updateAggsToRoot l
else do
updateAggs p
upLeft p
joinRight
:: (PrimMonad m, Monoid v)
=> Tree (PrimState m) a v
-> Tree (PrimState m) a v
-> Maybe (Tree (PrimState m) a v)
-> m (Tree (PrimState m) a v)
joinRight l c mbR = do
lr <- MutVar.readMutVar (tRight l)
lra <- rightAggs l lr
ll <- MutVar.readMutVar (tLeft l)
lla <- leftAggs l ll
ra <- maybe (return emptyAggs) (MutVar.readMutVar . tAggs) mbR
if aHeight lra <= aHeight ra + 1 then do
setRight l c
when (lr /= l) $ setLeft c lr
case mbR of Just r -> setRight c r; _ -> return ()
let !ca = joinAggs lra (tValue c) ra
if aHeight lla + 1 < aHeight ca then do
rotateRight c lr
rotateLeft l lr
updateAggs l
updateAggs c
updateAggsToRoot lr
else do
updateAggs c
updateAggs l
upRight l
else
joinRight lr c mbR
upRight
:: (PrimMonad m, Monoid v)
=> Tree (PrimState m) a v
-> m (Tree (PrimState m) a v)
upRight r = do
p <- MutVar.readMutVar (tParent r)
if p == r then
return p
else do
l <- MutVar.readMutVar (tLeft p)
la <- leftAggs p l
ra <- rightAggs p r
if aHeight la + 1 < aHeight ra then do
rotateLeft p r
updateAggs p
updateAggsToRoot r
else do
updateAggs p
upRight p
rotateLeft, rotateRight
:: PrimMonad m
=> Tree (PrimState m) a v
-> Tree (PrimState m) a v
-> m ()
rotateLeft p x = do
b <- MutVar.readMutVar (tLeft x)
if b == x then removeRight p else setRight p b
gp <- MutVar.readMutVar (tParent p)
if gp == p then removeParent x else replace gp p x
setLeft x p
rotateRight p x = do
b <- MutVar.readMutVar (tRight x)
if b == x then removeLeft p else setLeft p b
gp <- MutVar.readMutVar (tParent p)
if gp == p then removeParent x else replace gp p x
setRight x p
setLeft, setRight
:: PrimMonad m
=> Tree (PrimState m) a v
-> Tree (PrimState m) a v
-> m ()
setLeft p x = do
MutVar.writeMutVar (tParent x) p
MutVar.writeMutVar (tLeft p) x
setRight p x = do
MutVar.writeMutVar (tParent x) p
MutVar.writeMutVar (tRight p) x
removeParent, removeLeft, removeRight
:: PrimMonad m
=> Tree (PrimState m) a v
-> m ()
removeParent x = MutVar.writeMutVar (tParent x) x
removeLeft x = MutVar.writeMutVar (tLeft x) x
removeRight x = MutVar.writeMutVar (tRight x) x
leftAggs, rightAggs
:: (PrimMonad m, Monoid v)
=> Tree (PrimState m) a v
-> Tree (PrimState m) a v
-> m (Aggs v)
leftAggs p l =
if p == l then return emptyAggs else MutVar.readMutVar (tAggs l)
rightAggs p r =
if p == r then return emptyAggs else MutVar.readMutVar (tAggs r)
replace
:: PrimMonad m
=> Tree (PrimState m) a v
-> Tree (PrimState m) a v
-> Tree (PrimState m) a v
-> m ()
replace p x y = do
pl <- MutVar.readMutVar (tLeft p)
MutVar.writeMutVar (tParent y) p
if pl == x
then MutVar.writeMutVar (tLeft p) y
else MutVar.writeMutVar (tRight p) y
updateAggs
:: (Monoid v, PrimMonad m)
=> Tree (PrimState m) a v
-> m ()
updateAggs t = do
l <- MutVar.readMutVar (tLeft t)
r <- MutVar.readMutVar (tRight t)
la <- leftAggs t l
ra <- rightAggs t r
let !agg = joinAggs la (tValue t) ra
MutVar.writeMutVar (tAggs t) agg
updateAggsToRoot
:: (PrimMonad m, Monoid v)
=> Tree (PrimState m) a v
-> m (Tree (PrimState m) a v)
updateAggsToRoot x = do
updateAggs x
p <- MutVar.readMutVar (tParent x)
if p == x then return x else updateAggsToRoot p
freeze :: PrimMonad m => Tree (PrimState m) a v -> m (Tree.Tree a)
freeze tree@Tree {..} = do
left <- MutVar.readMutVar tLeft
right <- MutVar.readMutVar tRight
children <- sequence $
[freeze left | left /= tree] ++
[freeze right | right /= tree]
return $ Tree.Node tLabel children
print :: Show a => Tree (PrimState IO) a v -> IO ()
print = go 0
where
go d t@Tree {..} = do
left <- MutVar.readMutVar tLeft
when (left /= t) $ go (d + 1) left
putStrLn $ replicate d ' ' ++ show tLabel
right <- MutVar.readMutVar tRight
when (right /= t) $ go (d + 1) right
assertInvariants
:: (PrimMonad m, Monoid v, Eq v, Show v) => Tree (PrimState m) a v -> m ()
assertInvariants t = do
_ <- computeAggs t t
return ()
where
computeAggs p x = do
p' <- MutVar.readMutVar (tParent x)
when (p /= p') $ fail "broken parent pointer"
l <- MutVar.readMutVar (tLeft x)
r <- MutVar.readMutVar (tRight x)
la <- if l == x then return emptyAggs else computeAggs x l
ra <- if r == x then return emptyAggs else computeAggs x r
let actualAggs = joinAggs la (tValue x) ra
storedAggs <- MutVar.readMutVar (tAggs x)
when (actualAggs /= storedAggs) $ fail $
"error in stored aggregates: " ++ show storedAggs ++
", actual: " ++ show actualAggs
when (abs (aHeight la - aHeight ra) > 1) $ fail "inbalanced"
return actualAggs
assertSingleton :: PrimMonad m => Tree (PrimState m) a v -> m ()
assertSingleton x = do
l <- MutVar.readMutVar (tLeft x)
r <- MutVar.readMutVar (tRight x)
p <- MutVar.readMutVar (tParent x)
when (l /= x || r /= x || p /= x) $ fail "not a singleton"
assertRoot :: PrimMonad m => Tree (PrimState m) a v -> m ()
assertRoot x = do
p <- MutVar.readMutVar (tParent x)
when (p /= x) $ fail "not the root"
data TreeGen s = TreeGen
instance Class.Tree Tree where
type TreeGen Tree = TreeGen
newTreeGen _ = return TreeGen
singleton _ = singleton
append = append
split = split
connected = connected
root = root
label = label
aggregate = aggregate
toList = toList
instance Class.TestTree Tree where
print = print
assertInvariants = assertInvariants
assertSingleton = assertSingleton
assertRoot = assertRoot