{-# LANGUAGE BangPatterns     #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiWayIf       #-}
{-# LANGUAGE RecordWildCards  #-}
{-# LANGUAGE TypeFamilies     #-}
module Data.Graph.Dynamic.Internal.Avl
    ( Tree

    , singleton
    , append
    , concat
    , join
    , split
    , root
    , connected
    , label
    , aggregate
    , toList

    -- * Debugging only
    , freeze
    , print
    , assertInvariants
    , assertSingleton
    , assertRoot
    ) where

import           Control.Monad                    (foldM, when)
import           Control.Monad.Primitive          (PrimMonad (..))
import qualified Data.Graph.Dynamic.Internal.Tree as Class
import           Data.List.NonEmpty               (NonEmpty)
import qualified Data.List.NonEmpty               as NonEmpty
import           Data.Monoid                      ((<>))
import           Data.Primitive.MutVar            (MutVar)
import qualified Data.Primitive.MutVar            as MutVar
import qualified Data.Tree                        as Tree
import           Prelude                          hiding (concat, print)

-- | A node in a mutable AVL tree.  Every node doubles as a handle to the
-- tree it lives in; 'root' follows parent pointers to find the tree's root.
--
-- Absent links are encoded as self-loops rather than with 'Maybe':
--
-- * a node whose 'tParent' points to itself is a root;
-- * a node whose 'tLeft' / 'tRight' points to itself has no child on that
--   side.
--
-- 'singleton' creates a node with all three links as self-loops.
data Tree s a v = Tree
    { tParent :: {-# UNPACK #-} !(MutVar s (Tree s a v))
      -- ^ Parent node, or this node itself when it is the root.
    , tLeft   :: {-# UNPACK #-} !(MutVar s (Tree s a v))
      -- ^ Left child, or this node itself when there is none.
    , tRight  :: {-# UNPACK #-} !(MutVar s (Tree s a v))
      -- ^ Right child, or this node itself when there is none.
    , tAggs   :: {-# UNPACK #-} !(MutVar s (Aggs v))
      -- ^ Cached height and aggregate of the subtree rooted at this node.
    , tLabel  :: !a
      -- ^ Immutable label, as returned by 'label' and 'toList'.
    , tValue  :: !v
      -- ^ This node's own contribution to the subtree aggregate.
    }

-- | Node identity rather than structural equality: two handles are equal
-- exactly when they share the same 'tParent' reference.  Each node
-- allocates its own references, so this amounts to pointer equality.
instance Eq (Tree s a v) where
    Tree {tParent = p1} == Tree {tParent = p2} = p1 == p2

-- | Data cached per subtree: its height and the monoidal aggregate of its
-- values, combined in in-order (left, node, right) order.
data Aggs v = Aggs
    { aHeight    :: {-# UNPACK #-} !Int
      -- ^ Height of the subtree: 0 when empty, 1 for a single node.
    , aAggregate :: !v
      -- ^ In-order combination of the subtree's node values.
    } deriving (Eq, Show)

-- | Cached data standing in for a missing subtree: zero height and the
-- neutral element as aggregate.
emptyAggs :: Monoid v => Aggs v
emptyAggs = Aggs { aHeight = 0, aAggregate = mempty }

-- | Cached data for a lone node: height one, with the node's own value as
-- the whole aggregate.
singletonAggs :: v -> Aggs v
singletonAggs v = Aggs { aHeight = 1, aAggregate = v }

-- | Cached data of a node, computed from its left child's data, its own value
-- and its right child's data.  The height is one more than the taller child.
-- The aggregate is combined in in-order order: left, node, right.
joinAggs :: Monoid v => Aggs v -> v -> Aggs v -> Aggs v
joinAggs left x right = Aggs
    { aHeight    = 1 + max (aHeight left) (aHeight right)
    , aAggregate = aAggregate left <> x <> aAggregate right
    }

-- | Create a detached one-node tree holding the given label and value.
-- All three links are self-loops, which is how this module encodes
-- "no parent" and "no child".
singleton :: PrimMonad m => a -> v -> m (Tree (PrimState m) a v)
singleton lbl val = do
    -- The link references need some initial content before the node they
    -- should point to exists; these placeholders are overwritten below
    -- before the node is handed out.
    parentRef <- MutVar.newMutVar undefined
    leftRef   <- MutVar.newMutVar undefined
    rightRef  <- MutVar.newMutVar undefined
    aggsRef   <- MutVar.newMutVar (singletonAggs val)
    let node = Tree
            { tParent = parentRef
            , tLeft   = leftRef
            , tRight  = rightRef
            , tAggs   = aggsRef
            , tLabel  = lbl
            , tValue  = val
            }
    mapM_ (`MutVar.writeMutVar` node) [parentRef, leftRef, rightRef]
    return node

-- | The root of the tree containing the given node.  It is found by following
-- parent links until reaching a node that is its own parent.
root :: PrimMonad m => Tree (PrimState m) a v -> m (Tree (PrimState m) a v)
root node = do
    up <- MutVar.readMutVar (tParent node)
    if up /= node
        then root up
        else return node

-- | Append a non-empty sequence of trees from left to right, returning the
-- combined tree.  Each step is a call to 'append' with the accumulated tree
-- on the left.
concat
    :: (PrimMonad m, Monoid v)
    => NonEmpty (Tree (PrimState m) a v)
    -> m (Tree (PrimState m) a v)
concat trees = foldM append (NonEmpty.head trees) (NonEmpty.tail trees)

-- | Remove the given node from its tree, splitting that tree around it.
--
-- Returns the tree of all nodes before @x0@ and the tree of all nodes after
-- @x0@ (in in-order order), each 'Nothing' when empty.  Afterwards @x0@
-- itself is a detached singleton: @cut@ resets all of its links to
-- self-loops and recomputes its cached data.
split
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> m (Maybe (Tree (PrimState m) a v), Maybe (Tree (PrimState m) a v))
split x0 = do
    (mbL, mbR, p, left) <- cut x0
    if p == x0 then
        -- x0 was the root, so its former children are exactly the answer.
        return (mbL, mbR)
    else do
        upwards mbL mbR p left
  where
    -- Walk from a former ancestor @x@ of x0 up to the old root, detaching
    -- each ancestor and joining it into one of the two accumulators.
    -- @lacc0@ / @racc0@ hold everything found so far before / after x0.
    -- @left0@ says whether the node cut previously was @x@'s left child.
    upwards lacc0 racc0 x left0 = do
        (mbL, mbR, p, left1) <- cut x
        if left0 then do
            -- x0 is in x's left subtree, so x and its right subtree come
            -- after everything in racc0.  The previous cut already cleared
            -- x's left link, so mbL is Nothing here and is ignored.
            racc1 <- join racc0 x mbR
            if p == x then
                return (lacc0, Just racc1)
            else
                upwards lacc0 (Just racc1) p left1
        else do
            -- Mirror image: x0 is in x's right subtree, so x's left subtree
            -- and x come before everything in lacc0.
            lacc1 <- join mbL x lacc0
            if p == x then
                return (Just lacc1, racc0)
            else
                upwards (Just lacc1) racc0 p left1

    -- Detach @x@ from its parent and both children, turning it into a
    -- singleton with freshly recomputed cached data.  Returns
    --
    -- * its former left child and its former right child (each 'Nothing'
    --   when absent; both are now roots);
    -- * its former parent (@x@ itself if @x@ was the root);
    -- * whether @x@ was its parent's left child (meaningless for a root).
    --
    -- The parent's cached data is NOT recomputed here.  All links are read
    -- before any of them is overwritten.
    cut x = do
        p  <- MutVar.readMutVar (tParent x)
        pl <- MutVar.readMutVar (tLeft p)
        l <- MutVar.readMutVar (tLeft x)
        r <- MutVar.readMutVar (tRight x)
        when (l /= x) $ removeParent l
        when (r /= x) $ removeParent r
        removeParent x
        removeLeft  x
        removeRight x
        updateAggs x
        -- Clear whichever child slot of the parent held x.  When x is the
        -- root (p == x) this only touches x's own, already cleared, links.
        if pl == x then removeLeft p else removeRight p
        return
            ( if l == x then Nothing else Just l
            , if r == x then Nothing else Just r
            , p
            , pl == x
            )

-- | Concatenate two trees so that the nodes of @l0@ come before the nodes
-- of @r0@, and return the root of the result.
--
-- The rightmost node reachable from @l0@ is detached with 'split' and then
-- used as the middle node of a 'join'.  If that node turns out to have
-- successors, this fails with @"append: invalid state"@; that can happen
-- when @l0@ is not the root of its tree.  NOTE(review): @r0@ is passed to
-- 'join' as-is, so it presumably must be a root too — confirm with callers.
append
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> Tree (PrimState m) a v
    -> m (Tree (PrimState m) a v)
append l0 r0 = do
    -- NOTE: following the right spine and joining along the way would be
    -- faster than a full split.
    pivot <- rightMost l0
    (before, after) <- split pivot
    maybe (assertSingleton pivot) (\_ -> error "append: invalid state") after
    join before pivot (Just r0)
  where
    rightMost node = do
        next <- MutVar.readMutVar (tRight node)
        if next /= node then rightMost next else return node

-- | Whether two nodes belong to the same tree.  The answer is decided by
-- comparing their roots.
connected
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> Tree (PrimState m) a v
    -> m Bool
connected x y = (==) <$> root x <*> root y

-- | The label stored in a node.  Labels are plain immutable fields, so
-- nothing is read from a mutable reference here.
label :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> m a
label node = return (tLabel node)

-- | The cached aggregate of the subtree rooted at the given node.  When
-- called on a root, this is the aggregate of the whole tree.
aggregate :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> m v
aggregate node = aAggregate <$> MutVar.readMutVar (tAggs node)

-- | For debugging/testing.  The labels of the subtree rooted at the given
-- node, in in-order (left-to-right) order.
toList
    :: PrimMonad m => Tree (PrimState m) a v -> m [a]
toList = walk []
  where
    -- Prepend the labels of @node@'s subtree to @rest@.  The right side is
    -- handled first, so the list is built back-to-front without appends.
    walk rest node = do
        l <- MutVar.readMutVar (tLeft node)
        r <- MutVar.readMutVar (tRight node)
        withRight <- if r /= node then walk rest r else return rest
        let here = tLabel node : withRight
        if l /= node then walk here l else return here

-- | @join mbL c mbR@ builds a balanced tree holding @mbL@, then the
-- singleton @c@, then @mbR@, in that order, and returns its root.
--
-- When one side is more than one level taller than the other, the shorter
-- tree is attached along the spine of the taller one by 'joinRight' or
-- 'joinLeft'.  Otherwise @c@ simply becomes the new root, with the two
-- trees as its children.
join
    :: (PrimMonad m, Monoid v)
    => Maybe (Tree (PrimState m) a v)
    -> Tree (PrimState m) a v  -- Must be a singleton
    -> Maybe (Tree (PrimState m) a v)
    -> m (Tree (PrimState m) a v)
join mbL c mbR = do
    lh <- heightOf mbL
    rh <- heightOf mbR
    case () of
        _ | lh > rh + 1, Just l <- mbL -> joinRight l c mbR
          | rh > lh + 1, Just r <- mbR -> joinLeft mbL c r
          | otherwise -> do
              -- Heights are within one of each other: hang both under c.
              mapM_ (setLeft c) mbL
              mapM_ (setRight c) mbR
              updateAggs c
              return c
  where
    -- Height of an optional tree; an absent tree has height 0.
    heightOf mb =
        maybe (return 0) (fmap aHeight . MutVar.readMutVar . tAggs) mb

-- | @joinLeft mbL c r@ joins @mbL@, the singleton @c@ and @r@, in that order.
-- It is used when @r@ is more than one level taller than @mbL@, and is the
-- mirror image of 'joinRight'.
--
-- It descends the left spine of @r@ until it reaches a left subtree @rl@
-- that is at most one level taller than @mbL@.  That subtree is replaced by
-- @c@, with @mbL@ as @c@'s left child and @rl@ as its right child.  The AVL
-- balance is then restored: either by a double rotation at that point, or
-- by 'upLeft' on the way back up.  Returns the root of the resulting tree.
joinLeft
    :: (PrimMonad m, Monoid v)
    => Maybe (Tree (PrimState m) a v)
    -> Tree (PrimState m) a v  -- Must be a singleton
    -> Tree (PrimState m) a v
    -> m (Tree (PrimState m) a v)
joinLeft mbL c r = do
    -- Children of r (r itself when absent) and their cached data.
    rl  <- MutVar.readMutVar (tLeft r)
    rla <- leftAggs r rl

    rr  <- MutVar.readMutVar (tRight r)
    rra <- rightAggs r rr

    la  <- maybe (return emptyAggs) (MutVar.readMutVar . tAggs) mbL

    if aHeight rla <= aHeight la + 1 then do
        -- Splice c in as r's left child, with l on its left and the old
        -- left subtree rl on its right.
        setLeft r c
        when (rl /= r) $ setRight c rl
        case mbL of Just l -> setLeft c l; _ -> return ()

        -- Cached data c will have.  l is c's left child and rl its right
        -- child, so combine in that (in-order) order.  Only the height is
        -- used below, but the aggregate is then also correct, which matters
        -- for non-commutative monoids.
        let !ca = joinAggs la (tValue c) rla

        -- Invalidity in the parent is fixed with two rotations
        if aHeight rra + 1 < aHeight ca then do
            rotateLeft c rl
            rotateRight r rl

            -- rl is now the parent of both c and r: update children first.
            updateAggs c
            updateAggs r
            updateAggsToRoot rl
        else do
            -- No rotation at r itself; upLeft may still perform a single
            -- right rotation further up.
            updateAggs c
            updateAggs r
            upLeft r
    else
        joinLeft mbL c rl

-- | Walk upwards from @l@ after 'joinLeft' has grown the subtree at @l@.
-- On the way it recomputes each ancestor's cached data.  At the first
-- ancestor whose left side has become more than one level taller than its
-- right side, it performs a single right rotation and then only refreshes
-- cached data up to the root.  Returns the root of the whole tree.
--
-- NOTE(review): assumes @l@ is its parent's left child at every step, which
-- holds along the left spine that 'joinLeft' descends.
upLeft
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> m (Tree (PrimState m) a v)
upLeft l = do
    parent <- MutVar.readMutVar (tParent l)
    if parent /= l
        then do
            sibling  <- MutVar.readMutVar (tRight parent)
            rightAgg <- rightAggs parent sibling
            leftAgg  <- leftAggs parent l
            if aHeight leftAgg > aHeight rightAgg + 1
                then do
                    rotateRight parent l
                    updateAggs parent
                    updateAggsToRoot l
                else do
                    -- Stuff below us might have changed.
                    updateAggs parent
                    upLeft parent
        else return l

-- | @joinRight l c mbR@ joins @l@, the singleton @c@ and @mbR@, in that
-- order.  It is used when @l@ is more than one level taller than @mbR@.
--
-- It descends the right spine of @l@ until it reaches a right subtree @lr@
-- that is at most one level taller than @mbR@.  That subtree is replaced by
-- @c@, with @lr@ as @c@'s left child and @mbR@ as its right child.  The AVL
-- balance is then restored: either by a double rotation at that point, or
-- by 'upRight' on the way back up.  Returns the root of the resulting tree.
joinRight
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> Tree (PrimState m) a v  -- Must be a singleton
    -> Maybe (Tree (PrimState m) a v)
    -> m (Tree (PrimState m) a v)
joinRight l c mbR = do
    -- Children of l (l itself when absent) and their cached data.
    lr  <- MutVar.readMutVar (tRight l)
    lra <- rightAggs l lr

    ll  <- MutVar.readMutVar (tLeft l)
    lla <- leftAggs l ll

    ra <- maybe (return emptyAggs) (MutVar.readMutVar . tAggs) mbR

    if aHeight lra <= aHeight ra + 1 then do
        -- Splice c in as l's right child, with the old right subtree lr on
        -- its left and r on its right.
        setRight l c
        when (lr /= l) $ setLeft c lr
        case mbR of Just r -> setRight c r; _ -> return ()

        -- Cached data c will have (left, node, right); only the height is
        -- used below.
        let !ca = joinAggs lra (tValue c) ra

        -- Invalidity in the parent is fixed with two rotations
        if aHeight lla + 1 < aHeight ca then do
            rotateRight c lr
            rotateLeft l lr

            -- lr is now the parent of both l and c: update children first.
            -- Many of these are already computed...
            updateAggs l
            updateAggs c
            updateAggsToRoot lr
        else do
            -- No rotation at l itself; upRight may still perform a single
            -- left rotation further up.
            updateAggs c
            updateAggs l
            upRight l
    else
        joinRight lr c mbR

-- | Walk upwards from @r@ after 'joinRight' has grown the subtree at @r@.
-- On the way it recomputes each ancestor's cached data.  At the first
-- ancestor whose right side has become more than one level taller than its
-- left side, it performs a single left rotation and then only refreshes
-- cached data up to the root.  Returns the root of the whole tree.
--
-- NOTE(review): assumes @r@ is its parent's right child at every step,
-- which holds along the right spine that 'joinRight' descends.
upRight
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> m (Tree (PrimState m) a v)
upRight r = do
    parent <- MutVar.readMutVar (tParent r)
    if parent /= r
        then do
            sibling  <- MutVar.readMutVar (tLeft parent)
            leftAgg  <- leftAggs parent sibling
            rightAgg <- rightAggs parent r
            if aHeight rightAgg > aHeight leftAgg + 1
                then do
                    rotateLeft parent r
                    updateAggs parent
                    updateAggsToRoot r
                else do
                    -- Stuff below us might have changed.
                    updateAggs parent
                    upRight parent
        else return parent

-- | Single rotations around the edge between a node @p@ and its child @x@.
--
-- @rotateLeft p x@ expects @x@ to be the right child of @p@, and
-- @rotateRight p x@ expects @x@ to be the left child of @p@.
--
-- Afterwards @x@ has taken @p@'s place: it sits under @p@'s former parent,
-- or becomes the root if @p@ was the root.  @p@ becomes @x@'s child.  @x@'s
-- inner subtree @b@ moves across to @p@.
--
-- Only links are rewired.  Cached data is left stale; the callers in this
-- module recompute it afterwards.
rotateLeft, rotateRight
    :: PrimMonad m
    => Tree (PrimState m) a v  -- X's parent
    -> Tree (PrimState m) a v  -- X
    -> m ()
rotateLeft p x = do
    -- x's left subtree b becomes p's right subtree (or p loses its right
    -- child when b is absent).
    b <- MutVar.readMutVar (tLeft x)
    if b == x then removeRight p else setRight p b
    -- x takes p's slot under p's parent, or becomes the root.
    gp <- MutVar.readMutVar (tParent p)
    if gp == p then removeParent x else replace gp p x
    setLeft x p
rotateRight p x = do
    -- x's right subtree b becomes p's left subtree (or p loses its left
    -- child when b is absent).
    b <- MutVar.readMutVar (tRight x)
    if b == x then removeLeft p else setLeft p b
    -- x takes p's slot under p's parent, or becomes the root.
    gp <- MutVar.readMutVar (tParent p)
    if gp == p then removeParent x else replace gp p x
    setRight x p

-- | Install @child@ as the left (for 'setLeft') or right (for 'setRight')
-- child of @parent@, and point @child@'s parent link back at @parent@.
-- Whatever @parent@ previously held in that slot is not touched, so its
-- parent link still refers to @parent@.
setLeft, setRight
    :: PrimMonad m
    => Tree (PrimState m) a v  -- Parent
    -> Tree (PrimState m) a v  -- New child
    -> m ()
setLeft parent child =
    MutVar.writeMutVar (tParent child) parent >>
    MutVar.writeMutVar (tLeft parent) child
setRight parent child =
    MutVar.writeMutVar (tParent child) parent >>
    MutVar.writeMutVar (tRight parent) child

-- | Reset one link of a node to a self-loop:
--
-- * 'removeParent' makes the node a root;
-- * 'removeLeft' and 'removeRight' empty the corresponding child slot.
--
-- Only the given node is modified.  The node at the other end of the link
-- keeps its pointer.
removeParent, removeLeft, removeRight
    :: PrimMonad m
    => Tree (PrimState m) a v -- The node whose link is reset
    -> m ()
removeParent = pointAtSelf tParent
removeLeft   = pointAtSelf tLeft
removeRight  = pointAtSelf tRight

-- | Overwrite the selected link of a node with the node itself.
pointAtSelf
    :: PrimMonad m
    => (Tree (PrimState m) a v -> MutVar (PrimState m) (Tree (PrimState m) a v))
    -> Tree (PrimState m) a v
    -> m ()
pointAtSelf link node = MutVar.writeMutVar (link node) node

-- | Cached data of a node's left (for 'leftAggs') or right (for 'rightAggs')
-- child, given the value read from that child link.  A self-loop means
-- "no child" and yields 'emptyAggs'.  The two are the same check; the name
-- only records which slot the caller read.
leftAggs, rightAggs
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v  -- Parent
    -> Tree (PrimState m) a v  -- Left or right child
    -> m (Aggs v)
leftAggs parent child
    | parent == child = return emptyAggs
    | otherwise       = MutVar.readMutVar (tAggs child)
rightAggs = leftAggs

-- | Replace X by Y in the tree.  X must have a parent.
--
-- Y takes over whichever child slot of the parent currently holds X, and
-- Y's parent link is pointed at that parent.  X's own links are left as
-- they are.
replace
    :: PrimMonad m
    => Tree (PrimState m) a v  -- ^ X's parent
    -> Tree (PrimState m) a v  -- ^ X
    -> Tree (PrimState m) a v  -- ^ Y
    -> m ()
replace p x y = do
    leftOfP <- MutVar.readMutVar (tLeft p)
    let slot = if leftOfP == x then tLeft p else tRight p
    MutVar.writeMutVar (tParent y) p
    MutVar.writeMutVar slot y

-- | Recompute the cached height and aggregate of a single node.  The result
-- is derived from its children's cached data and its own value.  Because
-- the children's data is read rather than recomputed, children should be
-- brought up to date first.
updateAggs
    :: (Monoid v, PrimMonad m)
    => Tree (PrimState m) a v
    -> m ()
updateAggs node = do
    leftChild  <- MutVar.readMutVar (tLeft node)
    rightChild <- MutVar.readMutVar (tRight node)
    fresh <- joinAggs <$> leftAggs node leftChild
                      <*> pure (tValue node)
                      <*> rightAggs node rightChild
    -- Store an evaluated value so no thunk accumulates in the reference.
    MutVar.writeMutVar (tAggs node) $! fresh

-- | Recompute the cached data of a node and of every ancestor above it, in
-- bottom-up order.  Returns the root that is reached.
updateAggsToRoot
    :: (PrimMonad m, Monoid v)
    => Tree (PrimState m) a v
    -> m (Tree (PrimState m) a v)
updateAggsToRoot node =
    updateAggs node >> MutVar.readMutVar (tParent node) >>= \up ->
        if up == node then return node else updateAggsToRoot up

-- | For debugging/testing.  Copy the subtree rooted at the given node into
-- an immutable 'Tree.Tree' of labels.  In each node's child list, the left
-- child (if any) comes before the right child (if any).
freeze :: PrimMonad m => Tree (PrimState m) a v -> m (Tree.Tree a)
freeze node = do
    l <- MutVar.readMutVar (tLeft node)
    r <- MutVar.readMutVar (tRight node)
    -- Self-loops mark missing children; drop them.
    let present = filter (/= node) [l, r]
    Tree.Node (tLabel node) <$> mapM freeze present

-- | For debugging.  Print the subtree rooted at the given node in in-order
-- order, one label per line.  Each line is indented by one space per level
-- of depth.
print :: Show a => Tree (PrimState IO) a v -> IO ()
print = printAt 0
  where
    printAt depth node = do
        let visit child = when (child /= node) $ printAt (depth + 1) child
        MutVar.readMutVar (tLeft node) >>= visit
        putStrLn (replicate depth ' ' ++ show (tLabel node))
        MutVar.readMutVar (tRight node) >>= visit

-- | For debugging/testing.  Walk the subtree rooted at @t@ and 'fail' on the
-- first violation found.  The checks are:
--
-- * every node's parent link points at the node it was reached from (so
--   @t@ itself must be a root);
-- * every node's cached height and aggregate match a recomputation;
-- * the AVL balance condition: child heights differ by at most one.
assertInvariants
    :: (PrimMonad m, Monoid v, Eq v, Show v) => Tree (PrimState m) a v -> m ()
assertInvariants t = () <$ check t t
  where
    -- TODO: Check average
    -- Verify the subtree at @node@, whose parent link should equal
    -- @expected@, and return its recomputed cached data.
    check expected node = do
        actualParent <- MutVar.readMutVar (tParent node)
        when (expected /= actualParent) $ fail "broken parent pointer"

        l <- MutVar.readMutVar (tLeft node)
        r <- MutVar.readMutVar (tRight node)
        leftA  <- if l /= node then check node l else return emptyAggs
        rightA <- if r /= node then check node r else return emptyAggs

        let recomputed = joinAggs leftA (tValue node) rightA
        stored <- MutVar.readMutVar (tAggs node)
        when (recomputed /= stored) $ fail $
            "error in stored aggregates: " ++ show stored ++
            ", actual: " ++ show recomputed

        when (abs (aHeight leftA - aHeight rightA) > 1) $ fail "inbalanced"
        return recomputed

-- | For debugging/testing.  'fail' unless all three links of the node are
-- self-loops, i.e. unless it is a detached one-node tree.
assertSingleton :: PrimMonad m => Tree (PrimState m) a v -> m ()
assertSingleton x = do
    links <- mapM (MutVar.readMutVar . ($ x)) [tLeft, tRight, tParent]
    when (any (/= x) links) $ fail "not a singleton"

-- | For debugging/testing.  'fail' unless the node's parent link is a
-- self-loop, i.e. unless it is the root of its tree.
assertRoot :: PrimMonad m => Tree (PrimState m) a v -> m ()
assertRoot x =
    MutVar.readMutVar (tParent x) >>= \parent ->
        when (parent /= x) $ fail "not the root"

-- | Generator type required by 'Class.Tree'.  It carries no data: the
-- instance below ignores the generator when creating a node.
data TreeGen s = TreeGen

-- | The AVL implementation of the generic tree interface.  Each operation
-- forwards to the function of the same name in this module.
instance Class.Tree Tree where
    type TreeGen Tree = TreeGen
    -- The generator is stateless, so a fresh one is just the constructor.
    newTreeGen _ = return TreeGen

    -- The generator argument is not needed to allocate a node.
    singleton _ = singleton
    append      = append
    split       = split
    connected   = connected
    root        = root
    label       = label
    aggregate   = aggregate
    toList      = toList

-- | Debugging and test hooks, forwarded to this module's functions.
instance Class.TestTree Tree where
    print            = print
    assertInvariants = assertInvariants
    assertSingleton  = assertSingleton
    assertRoot       = assertRoot