{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Data.Graph.Dynamic.Internal.Splay
( Tree
, singleton
, cons
, snoc
, append
, split
, connected
, root
, aggregate
, toList
, readRoot
, freeze
, print
, assertInvariants
) where
import Control.Monad (when)
import Control.Monad.Primitive (PrimMonad (..))
import qualified Data.Graph.Dynamic.Internal.Tree as Class
import Data.Monoid ((<>))
import Data.Primitive.MutVar (MutVar)
import qualified Data.Primitive.MutVar as MutVar
import qualified Data.Tree as Tree
import Prelude hiding (concat, print)
import System.IO.Unsafe (unsafePerformIO)
import Unsafe.Coerce (unsafeCoerce)
-- | A splay-tree node. Nodes live in 'MutVar's and refer to each other
-- through 'Tree' handles; a missing parent or child is represented by 'nil'.
data T s a v = T
    { tParent :: {-# UNPACK #-} !(Tree s a v)
      -- ^ Parent node, or 'nil' for the root of a tree.
    , tLeft :: {-# UNPACK #-} !(Tree s a v)
      -- ^ Left child, or 'nil'.
    , tRight :: {-# UNPACK #-} !(Tree s a v)
      -- ^ Right child, or 'nil'.
    , tLabel :: !a
      -- ^ The element stored at this node (read by 'toList' and 'label').
    , tValue :: !v
      -- ^ This node's own value.
    , tAgg :: !v
      -- ^ Cached aggregate of the subtree rooted here: the left subtree's
      -- aggregate, then 'tValue', then the right subtree's, combined with
      -- '<>' ('assertInvariants' checks exactly this).
    }
-- | Sentinel handle meaning "no parent" / "no child". Test for it with '=='.
--
-- It is a single 'MutVar' allocated in 'IO' through 'unsafePerformIO' and
-- 'unsafeCoerce'd to every state thread @s@, so one value can be compared
-- against the pointers of any tree. Its contents are 'undefined', so it must
-- never be dereferenced; code in this module compares against 'nil' before
-- following a parent or child pointer. The NOINLINE pragma stops GHC from
-- duplicating the allocation, which would produce distinct sentinels that
-- no longer compare equal.
nil :: Tree s a v
nil = unsafeCoerce $ unsafePerformIO $ fmap Tree $ MutVar.newMutVar undefined
{-# NOINLINE nil #-}
-- | A handle to a splay-tree node. The derived 'Eq' compares the underlying
-- 'MutVar's, i.e. two handles are equal when they refer to the same node.
newtype Tree s a v = Tree {unTree :: MutVar s (T s a v)}
    deriving (Eq)
-- | Allocate a one-node tree holding the given label and value. The node has
-- no parent or children, and its aggregate is just its own value.
singleton :: PrimMonad m => a -> v -> m (Tree (PrimState m) a v)
singleton lbl val = Tree <$> (MutVar.newMutVar $! node)
  where
    node = T nil nil nil lbl val val
-- | Follow parent pointers from @tree@ up to the root of its tree. Unlike
-- 'root', this only reads nodes and does not restructure the tree.
readRoot :: PrimMonad m => Tree (PrimState m) a v -> m (Tree (PrimState m) a v)
readRoot tree@(Tree ref) = do
    parent <- tParent <$> MutVar.readMutVar ref
    if parent == nil
        then return tree
        else readRoot parent
-- | Hang @rt@ as the right subtree of @lt@ and return @lt@. The aggregate of
-- @lt@ becomes its old aggregate followed by @rt@'s.
--
-- NOTE(review): this overwrites any existing right child of @lt@ and any
-- existing parent of @rt@, so it appears to assume @lt@ has no right child
-- and @rt@ is a root. Confirm at call sites.
cons
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v -> Tree (PrimState m) a v
    -> m (Tree (PrimState m) a v)
cons lt@(Tree lref) rt@(Tree rref) = do
    rnode <- MutVar.readMutVar rref
    lnode <- MutVar.readMutVar lref
    MutVar.writeMutVar lref $!
        lnode {tRight = rt, tAgg = tAgg lnode <> tAgg rnode}
    MutVar.writeMutVar rref $! rnode {tParent = lt}
    return lt
-- | Hang @lt@ as the left subtree of @rt@ and return @rt@. The aggregate of
-- @rt@ becomes @lt@'s aggregate followed by @rt@'s old one.
--
-- NOTE(review): this overwrites any existing left child of @rt@ and any
-- existing parent of @lt@, so it appears to assume @rt@ has no left child
-- and @lt@ is a root. Confirm at call sites.
snoc
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v -> Tree (PrimState m) a v
    -> m (Tree (PrimState m) a v)
snoc lt@(Tree lref) rt@(Tree rref) = do
    lnode <- MutVar.readMutVar lref
    rnode <- MutVar.readMutVar rref
    MutVar.writeMutVar rref $!
        rnode {tLeft = lt, tAgg = tAgg lnode <> tAgg rnode}
    MutVar.writeMutVar lref $! lnode {tParent = rt}
    return rt
-- | Concatenate two trees so that the elements under @xt@ come before those
-- of @yt@'s tree. Returns the root of the combined tree.
--
-- The rightmost node below @xt@ is splayed to the root, which leaves it
-- without a right child, and @yt@ is attached there.
--
-- NOTE(review): @xt@ and @yt@ are presumably roots of two distinct trees.
-- The rightmost node is searched only below @xt@, and @yt@'s parent pointer
-- and aggregate are taken as-is. Confirm at call sites.
append
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> Tree (PrimState m) a v
    -> m (Tree (PrimState m) a v)
append xt yt@(Tree yv) = do
    rmt@(Tree rmv) <- getRightMost xt
    _ <- splay rmt
    y <- MutVar.readMutVar yv
    -- Strict modify, like every other node update in this module: the lazy
    -- 'MutVar.modifyMutVar' would store an unevaluated thunk in the node.
    MutVar.modifyMutVar' rmv $ \r -> r {tRight = yt, tAgg = tAgg r <> tAgg y}
    MutVar.writeMutVar yv $! y {tParent = rmt}
    return rmt
  where
    -- Follow right-child pointers down to the last node of the subtree.
    getRightMost tt@(Tree tv) = do
        t <- MutVar.readMutVar tv
        if tRight t == nil then return tt else getRightMost (tRight t)
-- | Splay @xt@ to the root and cut it away from both of its subtrees.
--
-- Afterwards @xt@ has no parent and no children, and its aggregate is reset
-- to its own value. The result holds the former left and right subtrees,
-- each now the root of its own tree, or 'Nothing' where that side was empty.
split
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> m (Maybe (Tree (PrimState m) a v), Maybe (Tree (PrimState m) a v))
split xt@(Tree xv) = do
    _ <- splay xt
    x <- MutVar.readMutVar xv
    let before = tLeft x
        after = tRight x
    -- The detached subtrees become roots.
    when (before /= nil) (removeParent before)
    when (after /= nil) (removeParent after)
    MutVar.writeMutVar xv $ x {tAgg = tValue x}
    removeLeft xt
    removeRight xt
    return (nonNil before, nonNil after)
  where
    nonNil t = if t == nil then Nothing else Just t
-- | Are @x@ and @y@ nodes of the same tree? Both may be restructured.
connected
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> Tree (PrimState m) a v
    -> m Bool
connected x y = do
    _ <- splay x
    -- With @x@ now at the root, 'splay' on @y@ returns the root it displaced.
    -- That root is @x@ exactly when @y@ lives in the same tree.
    (== x) <$> splay y
-- | Splay @x@ to the root of its tree and return it.
root
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> m (Tree (PrimState m) a v)
root x = x <$ splay x
-- | Read the label stored at a node.
label
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> m a
label tree = do
    node <- MutVar.readMutVar (unTree tree)
    return (tLabel node)
-- | Read the cached aggregate stored at a node. This covers the subtree
-- rooted at that node, so it is the whole tree's aggregate only when the
-- node is the root.
aggregate
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> m v
aggregate = fmap tAgg . MutVar.readMutVar . unTree
-- | In-order list of the labels in the subtree rooted at the given node.
toList
    :: PrimMonad m => Tree (PrimState m) a v -> m [a]
toList = collect []
  where
    -- Reverse in-order walk: the right subtree is visited first, so each
    -- label can be consed onto the front of an already-built suffix.
    collect acc (Tree ref) = do
        node <- MutVar.readMutVar ref
        suffix <- descend acc (tRight node)
        descend (tLabel node : suffix) (tLeft node)
    descend acc child
        | child == nil = return acc
        | otherwise = collect acc child
-- | Splay @xt@ to the root of its tree using zig, zig-zig and zig-zag
-- rotations, recomputing the cached 'tAgg' of every node whose subtree
-- changes along the way.
--
-- Returns the node that was the root before the splay, or @xt@ itself if it
-- already was the root. 'connected' relies on this return value.
splay
    :: forall m a v. (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> m (Tree (PrimState m) a v)
splay xt@(Tree xv) = do
    x0 <- MutVar.readMutVar xv
    -- Aggregates of x's two subtrees. 'go' passes them along and updates
    -- them after each rotation, so they are only read from the children here.
    xla <- if tLeft x0 == nil then return mempty else tAgg <$> MutVar.readMutVar (unTree $ tLeft x0)
    xra <- if tRight x0 == nil then return mempty else tAgg <$> MutVar.readMutVar (unTree $ tRight x0)
    go xt xla xra x0
  where
    -- One rotation step, repeated until x has no parent.
    --
    -- * closestToRootFound: the highest former ancestor of x passed so far
    --   (initially x itself). It is returned once x has no parent.
    -- * xla, xra: aggregates of x's current left and right subtrees.
    -- * x: working copy of x's node. It is written back to xv only when x
    --   ends up at the root; the copy stored in xv is stale until then.
    go :: Tree (PrimState m) a v -> v -> v -> T (PrimState m) a v
        -> m (Tree (PrimState m) a v)
    go closestToRootFound xla xra !x = do
        let !(pt@(Tree pv)) = tParent x
        if pt == nil then do
            -- x is the root: store the working copy and stop.
            MutVar.writeMutVar xv x
            return closestToRootFound
        else do
            p <- MutVar.readMutVar pv
            let gt@(Tree gv) = tParent p
            let plt@(Tree plv) = tLeft p
            let prt@(Tree prv) = tRight p
            let xlt@(Tree xlv) = tLeft x
            let xrt@(Tree xrv) = tRight x
            if | gt == nil, plt == xt -> do
                    -- Zig: p is the root and x its left child. x's right
                    -- subtree becomes p's left subtree and p becomes x's
                    -- right child. x, the new root, takes p's old aggregate.
                    when (xrt /= nil) $ MutVar.modifyMutVar' xrv $ \xr ->
                        xr {tParent = pt}
                    pra <- if prt == nil then return mempty else tAgg <$> MutVar.readMutVar prv
                    MutVar.writeMutVar pv $! p
                        { tLeft = xrt
                        , tParent = xt
                        , tAgg = xra <> tValue p <> pra
                        }
                    MutVar.writeMutVar xv $! x
                        { tAgg = tAgg p
                        , tRight = pt
                        , tParent = nil
                        }
                    return pt
               | gt == nil -> do
                    -- Zig (mirror): p is the root and x its right child.
                    when (xlt /= nil) $ MutVar.modifyMutVar' xlv $ \xl ->
                        xl {tParent = pt}
                    pla <- if plt == nil then return mempty else tAgg <$> MutVar.readMutVar plv
                    MutVar.writeMutVar pv $! p
                        { tRight = xlt
                        , tParent = xt
                        , tAgg = pla <> tValue p <> xla
                        }
                    MutVar.writeMutVar xv $! x
                        { tAgg = tAgg p
                        , tLeft = pt
                        , tParent = nil
                        }
                    return pt
               | otherwise -> do
                    -- x has a grandparent g: do a double rotation. x takes
                    -- g's place, so g's parent (if any) is first pointed at
                    -- x. In every case below, x inherits g's old aggregate.
                    g <- MutVar.readMutVar gv
                    let ggt@(Tree ggv) = tParent g
                    let glt@(Tree glv) = tLeft g
                    let grt@(Tree grv) = tRight g
                    when (ggt /= nil) $ MutVar.modifyMutVar' ggv $ \gg ->
                        if tLeft gg == gt
                            then gg {tLeft = xt}
                            else gg {tRight = xt}
                    if | plt == xt && glt == pt -> do
                            -- Zig-zig: x is p's left child, p is g's left
                            -- child. Result: x -> p (right) -> g (right).
                            pra <- if prt == nil then return mempty else tAgg <$> MutVar.readMutVar prv
                            gra <- if grt == nil then return mempty else tAgg <$> MutVar.readMutVar grv
                            let !ga' = pra <> tValue g <> gra
                            when (prt /= nil) $ MutVar.modifyMutVar' prv $ \pr ->
                                pr {tParent = gt}
                            MutVar.writeMutVar gv $! g
                                { tParent = pt
                                , tLeft = prt
                                , tAgg = ga'
                                }
                            when (xrt /= nil) $ MutVar.modifyMutVar' xrv $ \xr ->
                                xr {tParent = pt}
                            let !pa' = xra <> tValue p <> ga'
                            MutVar.writeMutVar pv $! p
                                { tParent = xt
                                , tLeft = xrt
                                , tRight = gt
                                , tAgg = pa'
                                }
                            go gt xla pa' $! x
                                { tRight = pt
                                , tAgg = tAgg g
                                , tParent = ggt
                                }
                       | plv /= xv && glv /= pv -> do
                            -- Zig-zig (mirror): x is p's right child, p is
                            -- g's right child. Result: x -> p (left) -> g (left).
                            pla <- if plt == nil then return mempty else tAgg <$> MutVar.readMutVar plv
                            gla <- if glt == nil then return mempty else tAgg <$> MutVar.readMutVar glv
                            let !ga' = gla <> tValue g <> pla
                            when (plt /= nil) $ MutVar.modifyMutVar' plv $ \pl ->
                                pl {tParent = gt}
                            MutVar.writeMutVar gv $! g
                                { tParent = pt
                                , tRight = plt
                                , tAgg = ga'
                                }
                            when (xlt /= nil) $ MutVar.modifyMutVar' xlv $ \xl ->
                                xl {tParent = pt}
                            let !pa' = ga' <> tValue p <> xla
                            MutVar.writeMutVar pv $! p
                                { tParent = xt
                                , tLeft = gt
                                , tRight = xlt
                                , tAgg = pa'
                                }
                            go gt pa' xra $! x
                                { tLeft = pt
                                , tAgg = tAgg g
                                , tParent = ggt
                                }
                       | plv == xv -> do
                            -- Zig-zag: x is p's left child and p is g's right
                            -- child (the zig-zig cases are ruled out above).
                            -- Result: g becomes x's left child and p its right.
                            when (xlt /= nil) $ MutVar.modifyMutVar' xlv $ \xl ->
                                xl {tParent = gt}
                            gla <- if glt == nil then return mempty else tAgg <$> MutVar.readMutVar glv
                            let !ga' = gla <> tValue g <> xla
                            MutVar.writeMutVar gv $! g
                                { tParent = xt
                                , tRight = xlt
                                , tAgg = ga'
                                }
                            when (xrt /= nil) $ MutVar.modifyMutVar' xrv $ \xr ->
                                xr {tParent = pt}
                            pra <- if prt == nil then return mempty else tAgg <$> MutVar.readMutVar prv
                            let pa' = xra <> tValue p <> pra
                            MutVar.writeMutVar pv $! p
                                { tParent = xt
                                , tLeft = xrt
                                , tAgg = pa'
                                }
                            go gt ga' pa' $! x
                                { tParent = ggt
                                , tLeft = gt
                                , tRight = pt
                                , tAgg = tAgg g
                                }
                       | otherwise -> do
                            -- Zig-zag (mirror): x is p's right child and p is
                            -- g's left child. Result: p becomes x's left child
                            -- and g its right.
                            when (xrt /= nil) $ MutVar.modifyMutVar' xrv $ \xr ->
                                xr {tParent = gt}
                            gra <- if grt == nil then return mempty else tAgg <$> MutVar.readMutVar grv
                            let !ga' = xra <> tValue g <> gra
                            MutVar.writeMutVar gv $! g
                                { tParent = xt
                                , tLeft = xrt
                                , tAgg = ga'
                                }
                            when (xlt /= nil) $ MutVar.modifyMutVar' xlv $ \xl ->
                                xl {tParent = pt}
                            pla <- if plt == nil then return mempty else tAgg <$> MutVar.readMutVar plv
                            let !pa' = pla <> tValue p <> xla
                            MutVar.writeMutVar pv $! p
                                { tParent = xt
                                , tRight = xlt
                                , tAgg = pa'
                                }
                            go gt pa' ga' $! x
                                { tParent = ggt
                                , tLeft = pt
                                , tRight = gt
                                , tAgg = tAgg g
                                }
-- | Clear one pointer of a node (parent, left child or right child) by
-- setting it to 'nil'. Only the given node is touched; the node on the other
-- end of the pointer is left unchanged.
removeParent, removeLeft, removeRight
    :: PrimMonad m
    => Tree (PrimState m) a v
    -> m ()
removeParent (Tree ref) = MutVar.modifyMutVar' ref (\node -> node {tParent = nil})
removeLeft (Tree ref) = MutVar.modifyMutVar' ref (\node -> node {tLeft = nil})
removeRight (Tree ref) = MutVar.modifyMutVar' ref (\node -> node {tRight = nil})
-- | Copy the subtree rooted at the given node into an immutable
-- 'Tree.Tree' of labels. Each node's children are its non-'nil' left and
-- right children, in that order.
freeze :: PrimMonad m => Tree (PrimState m) a v -> m (Tree.Tree a)
freeze (Tree tv) = do
    node <- MutVar.readMutVar tv
    children <- traverse freeze (filter (/= nil) [tLeft node, tRight node])
    return $ Tree.Node (tLabel node) children
-- | Debug helper: print the labels of the subtree in order, one per line,
-- each indented by one space per level of depth.
print :: Show a => Tree (PrimState IO) a v -> IO ()
print = go 0
  where
    go depth (Tree ref) = do
        node <- MutVar.readMutVar ref
        let child c = when (c /= nil) $ go (depth + 1) c
        child (tLeft node)
        putStrLn (replicate depth ' ' ++ show (tLabel node))
        child (tRight node)
-- | Check the structural invariants of the tree rooted at @t@ and raise an
-- error on the first violation found:
--
-- * every node's 'tParent' points at the node holding it as a child, and
--   @t@ itself has a 'nil' parent (so @t@ must be a root);
-- * every node's stored 'tAgg' equals its left subtree's aggregate, then its
--   'tValue', then its right subtree's aggregate, combined with '<>'.
--
-- This uses 'error' rather than 'fail'. Since base-4.13, 'fail' is a method
-- of 'MonadFail', which the 'PrimMonad' constraint does not provide, so the
-- 'fail'-based version does not typecheck on GHC >= 8.8.
assertInvariants
    :: (PrimMonad m, Monoid v, Eq v, Show v) => Tree (PrimState m) a v -> m ()
assertInvariants t = do
    _ <- computeAgg nil t
    return ()
  where
    -- Recompute the aggregate of the subtree at xt, checking that xt's parent
    -- pointer equals pt and that its stored aggregate matches.
    computeAgg pt xt@(Tree xv) = do
        x' <- MutVar.readMutVar xv
        let p' = tParent x'
        when (pt /= p') $ error "broken parent pointer"
        let l = tLeft x'
        let r = tRight x'
        la <- if l == nil then return mempty else computeAgg xt l
        ra <- if r == nil then return mempty else computeAgg xt r
        let actualAgg = la <> tValue x' <> ra
        let storedAgg = tAgg x'
        when (actualAgg /= storedAgg) $ error $
            "error in stored aggregates: " ++ show storedAgg ++
            ", actual: " ++ show actualAgg
        return actualAgg
-- | Generator state for the 'Class.Tree' instance. Splay trees need none, so
-- this is a unit type and its @s@ parameter is unused.
data TreeGen s = TreeGen
-- | Splay trees as an implementation of the tree interface from
-- "Data.Graph.Dynamic.Internal.Tree". No generator state is needed, so
-- 'newTreeGen' ignores its argument, and 'singleton' ignores the generator.
instance Class.Tree Tree where
    type TreeGen Tree = TreeGen
    newTreeGen = const (return TreeGen)
    singleton = const singleton
    append = append
    split = split
    connected = connected
    root = root
    readRoot = readRoot
    label = label
    aggregate = aggregate
    toList = toList
-- | Test hooks. The singleton and root checks are no-ops for splay trees.
instance Class.TestTree Tree where
    print = print
    assertInvariants = assertInvariants
    assertSingleton _ = return ()
    assertRoot _ = return ()